/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/graph/graph.h"

#include <memory>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/framework/full_type.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_properties.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph_node_util.h"
#include "tensorflow/core/graph/while_context.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

const int Graph::kControlSlot = -1;

// Node

Node::NodeClass Node::GetNodeClassForOp(const std::string& ts) {
  static const absl::flat_hash_map<std::string, Node::NodeClass>*
      kNodeClassTable =
#define REF_CLASS(key, value) \
  {key, value}, { "Ref" key, value }
          new absl::flat_hash_map<std::string, Node::NodeClass>({
              // Keep in same order as NodeClass values
              REF_CLASS("Switch", NC_SWITCH),
              REF_CLASS("_SwitchN", NC_SWITCH),
              REF_CLASS("Merge", NC_MERGE),
              REF_CLASS("Enter", NC_ENTER),
              REF_CLASS("Exit", NC_EXIT),
              REF_CLASS("NextIteration", NC_NEXT_ITERATION),
              {"LoopCond", NC_LOOP_COND},
              {"ControlTrigger", NC_CONTROL_TRIGGER},
              {"_Send", NC_SEND},
              {"_HostSend", NC_HOST_SEND},
              {"_Recv", NC_RECV},
              {"_HostRecv", NC_HOST_RECV},
              {"Const", NC_CONSTANT},
              {"HostConst", NC_CONSTANT},
              {"Variable", NC_VARIABLE},
              {"VariableV2", NC_VARIABLE},
              REF_CLASS("Identity", NC_IDENTITY),
              {"GetSessionHandle", NC_GET_SESSION_HANDLE},
              {"GetSessionHandleV2", NC_GET_SESSION_HANDLE},
              {"GetSessionTensor", NC_GET_SESSION_TENSOR},
              {"DeleteSessionTensor", NC_DELETE_SESSION_TENSOR},
              {"Size", NC_METADATA},
              {"Shape", NC_METADATA},
              {"Rank", NC_METADATA},
              {"_ScopedAllocator", NC_SCOPED_ALLOCATOR},
              {"CollectiveReduce", NC_COLLECTIVE},
              {"CollectiveBcastSend", NC_COLLECTIVE},
              {"CollectiveBcastRecv", NC_COLLECTIVE},
              {"CollectiveGather", NC_COLLECTIVE},
              {"FakeParam", NC_FAKE_PARAM},
              {"PartitionedCall", NC_PARTITIONED_CALL},
              {"StatefulPartitionedCall", NC_PARTITIONED_CALL},
              {"SymbolicGradient", NC_SYMBOLIC_GRADIENT},
              {"If", NC_IF},
              {"StatelessIf", NC_IF},
              {"While", NC_WHILE},
              {"StatelessWhile", NC_WHILE},
              {"Case", NC_CASE},
              {"StatelessCase", NC_CASE},
              // Not using the constants defined in FunctionLibraryDefinition
              // for the 4 ops below because the Android inference library
              // does not link the tf.function related files.
              {"_Arg", NC_ARG},
              {"_DeviceArg", NC_ARG},
              {"_Retval", NC_RETVAL},
              {"_DeviceRetval", NC_RETVAL},
              {"_XlaMerge", NC_MERGE},
          });
#undef REF_CLASS

  auto it = kNodeClassTable->find(ts);
  if (it != kNodeClassTable->end()) {
    return it->second;
  } else {
    return NC_OTHER;
  }
}

std::string Node::DebugString() const {
  std::string ret = strings::StrCat("{name:'", name(), "' id:", id_);
  if (IsSource()) {
    strings::StrAppend(&ret, " source}");
  } else if (IsSink()) {
    strings::StrAppend(&ret, " sink}");
  } else {
    strings::StrAppend(&ret, " op device:", "{requested: '", requested_device(),
                       "', assigned: '", assigned_device_name(), "'}", " def:{",
                       SummarizeNode(*this), "}}");
  }
  return ret;
}

Node::Node()
    : id_(-1),
      cost_id_(-1),
      class_(NC_UNINITIALIZED),
      props_(nullptr),
      assigned_device_name_index_(0),
      while_ctx_(nullptr) {}

void Node::Initialize(int id, int cost_id,
                      std::shared_ptr<NodeProperties> props,
                      Node::NodeClass node_class) {
  DCHECK_EQ(id_, -1);
  DCHECK(in_edges_.empty());
  DCHECK(out_edges_.empty());
  id_ = id;
  cost_id_ = cost_id;

  props_ = std::move(props);
  class_ = node_class;
}

void Node::Clear() {
  in_edges_.clear();
  out_edges_.clear();
  id_ = -1;
  cost_id_ = -1;
  class_ = NC_UNINITIALIZED;
  props_.reset();
  assigned_device_name_index_ = 0;
}

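// Recomputes this node's input/output types from its NodeDef and OpDef and,
// if they changed, stores the new types, replacing props_ with a fresh
// NodeProperties object when the current one is shared with other nodes.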
void Node::UpdateProperties() {
  DataTypeVector inputs;
  DataTypeVector outputs;
  Status status =
      InOutTypesForNode(props_->node_def, *(props_->op_def), &inputs, &outputs);
  if (!status.ok()) {
    LOG(ERROR) << "Failed at updating node: " << status;
    return;
  }
  if (props_->input_types != inputs || props_->output_types != outputs) {
    if (TF_PREDICT_TRUE(props_.use_count() == 1)) {
      props_->input_types = inputs;
      props_->input_types_slice = props_->input_types;
      props_->output_types = outputs;
      props_->output_types_slice = props_->output_types;
    } else {
      props_ = std::make_shared<NodeProperties>(
          props_->op_def, std::move(props_->node_def), inputs, outputs);
    }
  }
}

void Node::ClearTypeInfo() {
  if (props_->node_def.has_experimental_type()) {
    MaybeCopyOnWrite();
    props_->node_def.clear_experimental_type();
  }
}

const std::string& Node::name() const { return props_->node_def.name(); }
const std::string& Node::type_string() const { return props_->node_def.op(); }
const NodeDef& Node::def() const { return props_->node_def; }
const OpDef& Node::op_def() const { return *props_->op_def; }

NodeDef* Node::mutable_def() { return &props_->node_def; }

int32 Node::num_inputs() const { return props_->input_types.size(); }
DataType Node::input_type(int32_t i) const { return props_->input_types[i]; }
const DataTypeVector& Node::input_types() const { return props_->input_types; }

int32 Node::num_outputs() const { return props_->output_types.size(); }
DataType Node::output_type(int32_t o) const { return props_->output_types[o]; }
const DataTypeVector& Node::output_types() const {
  return props_->output_types;
}

AttrSlice Node::attrs() const { return AttrSlice(def()); }

const protobuf::RepeatedPtrField<std::string>& Node::requested_inputs() const {
  return def().input();
}

const std::string& Node::requested_device() const { return def().device(); }

gtl::iterator_range<NeighborIter> Node::out_nodes() const {
  return gtl::make_range(NeighborIter(out_edges_.begin(), false),
                         NeighborIter(out_edges_.end(), false));
}

gtl::iterator_range<NeighborIter> Node::in_nodes() const {
  return gtl::make_range(NeighborIter(in_edges_.begin(), true),
                         NeighborIter(in_edges_.end(), true));
}

void Node::MaybeCopyOnWrite() {
  // TODO(mdan): As nodes become more dynamic, this may not be worth the cost.
  // NodeProperties may be shared between Nodes. Make a copy if so.
  if (!props_.unique()) {
    props_ = std::make_shared<NodeProperties>(*props_);
  }
}

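// The setters below call MaybeCopyOnWrite() before touching node_def, so a
// NodeProperties object shared with other nodes is copied rather than
// modified in place.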
AttrValue* Node::AddAttrHelper(const std::string& name) {
  MaybeCopyOnWrite();
  return &((*props_->node_def.mutable_attr())[name]);
}

void Node::ClearAttr(const std::string& name) {
  MaybeCopyOnWrite();
  (*props_->node_def.mutable_attr()).erase(name);
}

void Node::set_name(std::string name) {
  MaybeCopyOnWrite();
  props_->node_def.set_name(std::move(name));
}

void Node::set_requested_device(const std::string& device) {
  MaybeCopyOnWrite();
  props_->node_def.set_device(device);
}

void Node::set_original_node_names(const std::vector<std::string>& names) {
  MaybeCopyOnWrite();
  props_->node_def.mutable_experimental_debug_info()
      ->clear_original_node_names();
  if (!names.empty()) {
    *props_->node_def.mutable_experimental_debug_info()
         ->mutable_original_node_names() = {names.begin(), names.end()};
  }
}

void Node::set_original_func_names(const std::vector<std::string>& names) {
  MaybeCopyOnWrite();
  props_->node_def.mutable_experimental_debug_info()
      ->clear_original_func_names();
  if (!names.empty()) {
    *props_->node_def.mutable_experimental_debug_info()
         ->mutable_original_func_names() = {names.begin(), names.end()};
  }
}

Status Node::input_edge(int idx, const Edge** e) const {
  if (idx < 0 || idx >= num_inputs()) {
    return errors::InvalidArgument("Invalid input_edge index: ", idx, ", Node ",
                                   name(), " only has ", num_inputs(),
                                   " inputs.");
  }

  // This does a linear search over the edges. In the common case,
  // the number of elements is small enough that this search isn't
  // expensive. Should it become a bottleneck, one can make an
  // optimization where, if the number of edges is small, we use
  // linear iteration, and if the number of edges is large, we perform
  // an indexing step during construction that keeps an array of Edges
  // indexed by pointer. This would keep the size of each Node small
  // in the common case but make this function faster when the number
  // of edges is large.
  for (const Edge* edge : in_edges()) {
    if (edge->dst_input() == idx) {
      *e = edge;
      return OkStatus();
    }
  }

  return errors::NotFound("Could not find input edge ", idx, " for ", name());
}

// Returns a vector of the non-control input edges to a node, indexed by ID.
Status Node::input_edges(std::vector<const Edge*>* input_edges) const {
  input_edges->clear();
  input_edges->resize(num_inputs(), nullptr);

  for (const Edge* edge : in_edges()) {
    if (edge->IsControlEdge()) continue;
    if (edge->dst_input() < 0 || edge->dst_input() >= num_inputs()) {
      return errors::Internal("Invalid edge input number ", edge->dst_input());
    }
    if ((*input_edges)[edge->dst_input()] != nullptr) {
      return errors::Internal("Duplicate edge input number: ",
                              edge->dst_input());
    }
    (*input_edges)[edge->dst_input()] = edge;
  }

  for (int i = 0; i < num_inputs(); ++i) {
    if ((*input_edges)[i] == nullptr) {
      return errors::InvalidArgument("Missing edge input number: ", i);
    }
  }
  return OkStatus();
}

Status Node::input_node(int idx, Node** n) const {
  const Edge* e;
  TF_RETURN_IF_ERROR(input_edge(idx, &e));
  if (e == nullptr) {
    *n = nullptr;
  } else {
    *n = e->src();
  }
  return OkStatus();
}

Status Node::input_node(int idx, const Node** const_n) const {
  Node* n;
  TF_RETURN_IF_ERROR(input_node(idx, &n));
  *const_n = n;
  return OkStatus();
}

Status Node::input_tensor(int idx, OutputTensor* t) const {
  const Edge* e;
  TF_RETURN_IF_ERROR(input_edge(idx, &e));
  DCHECK(e != nullptr);
  *t = OutputTensor(e->src(), e->src_output());
  return OkStatus();
}

// NodeDebugInfo

NodeDebugInfo::NodeDebugInfo(const Node& n) : NodeDebugInfo(n.def()) {}
NodeDebugInfo::NodeDebugInfo(const NodeDef& ndef)
    : NodeDebugInfo(ndef.name(), ndef.has_experimental_debug_info(),
                    ndef.experimental_debug_info()) {}
NodeDebugInfo::NodeDebugInfo(
    StringPiece node_name, bool has_experimental_debug_info,
    const NodeDef_ExperimentalDebugInfo& experimental_debug_info)
    : name(node_name) {
  if (has_experimental_debug_info) {
    const auto& node_names = experimental_debug_info.original_node_names();
    original_node_names.assign(node_names.begin(), node_names.end());
    const auto& func_names = experimental_debug_info.original_func_names();
    original_func_names.assign(func_names.begin(), func_names.end());
  }
}

// InputTensor

bool InputTensor::operator==(const InputTensor& other) const {
  return node == other.node && index == other.index;
}

uint64 InputTensor::Hash::operator()(InputTensor const& s) const {
  return Hash64Combine(std::hash<const Node*>()(s.node),
                       std::hash<int>()(s.index));
}

// OutputTensor

bool OutputTensor::operator==(const OutputTensor& other) const {
  return node == other.node && index == other.index;
}

uint64 OutputTensor::Hash::operator()(OutputTensor const& s) const {
  return Hash64Combine(std::hash<const Node*>()(s.node),
                       std::hash<int>()(s.index));
}

// Graph

Graph::Graph(const OpRegistryInterface* ops)
    : ops_(ops, FunctionDefLibrary()),
      versions_(new VersionDef),
      arena_(8 << 10 /* 8kB */) {
  versions_->set_producer(TF_GRAPH_DEF_VERSION);
  versions_->set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER);

  // Initialize the name interning table for assigned_device_name.
  device_names_.push_back("");
  DCHECK_EQ(0, InternDeviceName(""));

  // Source and sink have no endpoints, just control edges.
  NodeDef def;
  def.set_name("_SOURCE");
  def.set_op("NoOp");
  Status status;
  Node* source = AddNode(def, &status);
  TF_CHECK_OK(status);
  CHECK_EQ(source->id(), kSourceId);

  def.set_name("_SINK");
  Node* sink = AddNode(def, &status);
  TF_CHECK_OK(status);
  CHECK_EQ(sink->id(), kSinkId);

  AddControlEdge(source, sink);
}

Graph::Graph(const FunctionLibraryDefinition& flib_def)
    : Graph(flib_def.default_registry()) {
  // Need a new-enough consumer to support the functions we add to the graph.
  if (flib_def.num_functions() > 0 && versions_->min_consumer() < 12) {
    versions_->set_min_consumer(12);
  }
  Status s = ops_.AddLibrary(flib_def);
  CHECK(s.ok()) << s.error_message();
}

Graph::~Graph() {
  // Manually call the destructors for all the Nodes we constructed using
  // placement new.
  for (Node* node : nodes_) {
    if (node != nullptr) {
      node->~Node();
    }
  }
  for (Node* node : free_nodes_) {
    node->~Node();
  }
  // Edges have no destructor, and we arena-allocated them, so no need to
  // destroy them.
}

std::unique_ptr<Graph> Graph::Clone() {
  std::unique_ptr<Graph> new_graph(new Graph(flib_def()));
  new_graph->Copy(*this);
  return new_graph;
}

void Graph::Clear() {
  // Removes all nodes other than the source and sink, using the RemoveNode()
  // helper. If this becomes performance sensitive, we could skip the helper
  // and clear the underlying containers directly.
  for (Node* n : nodes()) {
    if (!n->IsSource() && !n->IsSink()) RemoveNode(n);
  }
}

const VersionDef& Graph::versions() const { return *versions_; }
void Graph::set_versions(const VersionDef& versions) { *versions_ = versions; }

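// Copies src's versions, nodes, and edges into this graph. The destination
// must contain only the source and sink nodes (as CHECKed below); Clone()
// relies on this when copying into a freshly constructed graph.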
void Graph::Copy(const Graph& src) {
  SetConstructionContext(src.GetConstructionContextInternal());
  for (Node* n : nodes()) {
    CHECK(n->IsSource() || n->IsSink()) << "*dest must be empty";
  }

  // Copy GraphDef versions
  set_versions(src.versions());

  // Copy the nodes.
  // "Node in src" -> "Node in *dest"
  gtl::FlatMap<const Node*, Node*> node_map;
  node_map.reserve(src.num_nodes());
  node_map[src.source_node()] = source_node();
  node_map[src.sink_node()] = sink_node();
  for (Node* n : src.op_nodes()) {
    auto copy = CopyNode(n);
    copy->in_edges_.reserve(n->in_edges().size());
    copy->out_edges_.reserve(n->out_edges().size());
    node_map[n] = copy;
  }

  // Copy the edges
  edges_.reserve(src.num_edges());
  for (const Edge* e : src.edges()) {
    Node* src_copy = node_map[e->src()];
    Node* dst_copy = node_map[e->dst()];
    AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
  }
}

StatusOr<Node*> Graph::AddNode(NodeDef node_def) {
  Status s;
  Node* out = AddNode(std::move(node_def), &s);
  TF_RETURN_IF_ERROR(s);
  return out;
}

Node* Graph::AddNode(NodeDef node_def, Status* status) {
  const OpRegistrationData* op_reg_data;
  status->Update(ops_.LookUp(node_def.op(), &op_reg_data));
  if (!status->ok()) return nullptr;

  DataTypeVector inputs;
  DataTypeVector outputs;
  status->Update(
      InOutTypesForNode(node_def, op_reg_data->op_def, &inputs, &outputs));
  if (!status->ok()) {
    *status = AttachDef(*status, node_def);
    return nullptr;
  }

  Node::NodeClass node_class = op_reg_data->is_function_op
                                   ? Node::NC_FUNCTION_OP
                                   : Node::GetNodeClassForOp(node_def.op());

  if (node_def.has_experimental_type()) {
    VLOG(3) << "AddNode: node has type set, skipping type constructor "
            << node_def.name();
  } else {
    if (op_reg_data->type_ctor != nullptr) {
      VLOG(3) << "AddNode: found type constructor for " << node_def.name();
      Status s =
          full_type::SpecializeType(AttrSlice(node_def), op_reg_data->op_def,
                                    *(node_def.mutable_experimental_type()));
      if (!s.ok()) {
        *status = errors::InvalidArgument("type error: ", s.ToString());
        VLOG(3) << "AddNode: type inference failed for " << node_def.name()
                << ": " << s;
        return nullptr;
      }
    } else {
      VLOG(3) << "AddNode: no type constructor for " << node_def.name();
    }
  }

  Node* node = AllocateNode(
      std::make_shared<NodeProperties>(&op_reg_data->op_def,
                                       std::move(node_def), inputs, outputs),
      nullptr, node_class);
  return node;
}

Node* Graph::CopyNode(const Node* node) {
  DCHECK(!node->IsSource());
  DCHECK(!node->IsSink());
  Node* copy = AllocateNode(node->props_, node, node->class_);
  copy->set_assigned_device_name(node->assigned_device_name());

  // Since the OpDef of a function may be owned by the Graph that owns 'node',
  // relookup the OpDef in the target graph. If it differs, then clone the
  // node properties with the updated OpDef.
  const OpDef* op_def;
  TF_CHECK_OK(ops_.LookUpOpDef(node->type_string(), &op_def));
  if (op_def != node->props_->op_def) {
    copy->MaybeCopyOnWrite();
    copy->props_->op_def = op_def;
  }
  copy->SetStackTrace(node->GetStackTrace());

  return copy;
}

void Graph::RemoveNode(Node* node) {
  TF_DCHECK_OK(IsValidNode(node)) << node->DebugString();
  DCHECK(!node->IsSource());
  DCHECK(!node->IsSink());

  // Remove any edges involving this node.
  for (const Edge* e : node->in_edges_) {
    CHECK_EQ(e->src_->out_edges_.erase(e), size_t{1});
    edges_[e->id_] = nullptr;
    RecycleEdge(e);
    --num_edges_;
  }
  node->in_edges_.clear();
  for (const Edge* e : node->out_edges_) {
    CHECK_EQ(e->dst_->in_edges_.erase(e), size_t{1});
    edges_[e->id_] = nullptr;
    RecycleEdge(e);
    --num_edges_;
  }
  node->out_edges_.clear();
  ReleaseNode(node);
}

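// Adds an edge from source:x to dest:y. Edge objects are placement-new'd in
// the arena or reused from free_edges_, and their id is their index in
// edges_. Edges leaving the source node, entering the sink node, or using
// kControlSlot on either endpoint must be control edges on both endpoints.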
const Edge* Graph::AddEdge(Node* source, int x, Node* dest, int y) {
  TF_DCHECK_OK(IsValidNode(source)) << source->DebugString();
  TF_DCHECK_OK(IsValidNode(dest)) << dest->DebugString();

  // source/sink must only be linked via control slots, and
  // control slots must only be linked to control slots.
  if (source == source_node() || dest == sink_node() || x == kControlSlot ||
      y == kControlSlot) {
    DCHECK_EQ(x, kControlSlot) << source->DebugString();
    DCHECK_EQ(y, kControlSlot) << dest->DebugString();
  }

  Edge* e = nullptr;
  if (free_edges_.empty()) {
    e = new (arena_.Alloc(sizeof(Edge))) Edge;  // placement new
  } else {
    e = free_edges_.back();
    free_edges_.pop_back();
  }
  e->id_ = edges_.size();
  e->src_ = source;
  e->dst_ = dest;
  e->src_output_ = x;
  e->dst_input_ = y;
  CHECK(source->out_edges_.insert(e).second);
  CHECK(dest->in_edges_.insert(e).second);
  edges_.push_back(e);
  ++num_edges_;

  return e;
}

void Graph::RemoveEdge(const Edge* e) {
  TF_DCHECK_OK(IsValidNode(e->src_)) << e->src_->DebugString();
  TF_DCHECK_OK(IsValidNode(e->dst_)) << e->dst_->DebugString();
  CHECK_EQ(e->src_->out_edges_.erase(e), size_t{1});
  CHECK_EQ(e->dst_->in_edges_.erase(e), size_t{1});
  CHECK_EQ(e, edges_[e->id_]);
  CHECK_GT(num_edges_, 0);

  edges_[e->id_] = nullptr;
  RecycleEdge(e);
  --num_edges_;
}

void Graph::RecycleEdge(const Edge* e) {
  free_edges_.push_back(const_cast<Edge*>(e));
}

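// Adds a control edge from source to dest. When allow_duplicates is false and
// such an edge already exists, nothing is added and nullptr is returned; when
// a new edge is created, a matching "^source" entry is also added to dest's
// NodeDef inputs unless source is _SOURCE, dest is _SINK, or the entry is
// already present. When allow_duplicates is true, the NodeDef is left
// untouched.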
const Edge* Graph::AddControlEdge(Node* source, Node* dest,
                                  bool allow_duplicates) {
  if (!allow_duplicates) {
    for (const Edge* edge : dest->in_edges()) {
      if (edge->IsControlEdge() && edge->src() == source) {
        // The requested edge already exists.
        return nullptr;
      }
    }
  }
  // Modify dest's NodeDef if necessary.
  if (!source->IsSource() && !dest->IsSink() && !allow_duplicates) {
    // Check if this input is already in dest's NodeDef.
    const std::string new_input = strings::StrCat("^", source->name());
    bool input_exists = false;
    for (const std::string& input : dest->props_->node_def.input()) {
      if (input == new_input) {
        input_exists = true;
        break;
      }
    }
    if (!input_exists) {
      dest->MaybeCopyOnWrite();
      dest->props_->node_def.add_input(new_input);
    }
  }
  return AddEdge(source, kControlSlot, dest, kControlSlot);
}

void Graph::RemoveControlEdge(const Edge* e) {
  if (!e->src_->IsSource() && !e->dst_->IsSink()) {
    e->dst_->MaybeCopyOnWrite();
    std::string e_src_name = strings::StrCat("^", e->src_->name());
    auto* inputs = e->dst_->props_->node_def.mutable_input();
    for (auto it = inputs->begin(); it != inputs->end(); ++it) {
      if (*it == e_src_name) {
        inputs->erase(it);
        break;
      }
    }
  }
  RemoveEdge(e);
}

namespace {
const Edge* FindEdge(const Node* dst, int index) {
  for (const Edge* e : dst->in_edges()) {
    if (e->dst_input() == index) return e;
  }
  return nullptr;
}
}  // namespace

Status Graph::UpdateEdge(Node* new_src, int new_src_index, Node* dst,
                         int dst_index) {
  TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index));
  TF_RETURN_IF_ERROR(IsValidInputTensor(dst, dst_index));
  const Edge* e = FindEdge(dst, dst_index);
  if (e == nullptr) {
    return errors::InvalidArgument("Couldn't find edge to ",
                                   FormatNodeForError(*dst));
  }
  RemoveEdge(e);
  AddEdge(new_src, new_src_index, dst, dst_index);
  dst->MaybeCopyOnWrite();
  (*dst->props_->node_def.mutable_input())[dst_index] =
      strings::StrCat(new_src->name(), ":", new_src_index);
  return OkStatus();
}

Status Graph::AddWhileInputHack(Node* new_src, int new_src_index, Node* dst) {
  if (!dst->IsWhileNode()) {
    return errors::Internal(
        "dst argument to AddWhileEdgeHack should be a While op, got: ",
        dst->DebugString());
  }
  TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index));
  // Find the current number of data inputs. We'll add the new edge to the next
  // missing data input.
  int dst_index = 0;
  for (const Edge* edge : dst->in_edges()) {
    if (edge->IsControlEdge()) continue;
    ++dst_index;
  }
  TF_RETURN_IF_ERROR(IsValidInputTensor(dst, dst_index));
  AddEdge(new_src, new_src_index, dst, dst_index);
  dst->MaybeCopyOnWrite();
  dst->props_->node_def.add_input(
      strings::StrCat(new_src->name(), ":", new_src_index));
  return OkStatus();
}

Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) {
  // Need a new-enough consumer to support the functions we add to the graph.
  if (fdef_lib.function_size() > 0 && versions_->min_consumer() < 12) {
    versions_->set_min_consumer(12);
  }
  return ops_.AddLibrary(fdef_lib);
}

namespace {

void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) {
  if (src_slot == Graph::kControlSlot) {
    dst->add_input(strings::StrCat("^", src_name));
  } else if (src_slot == 0) {
    dst->add_input(src_name.data(), src_name.size());
  } else {
    dst->add_input(strings::StrCat(src_name, ":", src_slot));
  }
}

}  // namespace

void Graph::ToGraphDef(GraphDef* graph_def) const {
  ToGraphDefSubRange(graph_def, 0);
}

GraphDef Graph::ToGraphDefDebug() const {
  GraphDef ret;
  ToGraphDef(&ret);
  return ret;
}

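// Serializes all op nodes with id >= from_node_id into graph_def, along with
// the graph's versions and function library. Data inputs are emitted in slot
// order followed by control inputs sorted by source node name, and the
// assigned device (if any) overrides the requested device in the output
// NodeDef.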
void Graph::ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const {
  graph_def->Clear();
  *graph_def->mutable_versions() = versions();
  *graph_def->mutable_library() = ops_.ToProto();

  graph_def->mutable_node()->Reserve(std::max(1, num_nodes() - from_node_id));

  std::vector<const Edge*>
      inputs;  // Construct this outside the loop for speed.
  for (auto id = from_node_id; id < num_node_ids(); ++id) {
    const Node* node = FindNodeId(id);
    if (node == nullptr || !node->IsOp()) continue;
    NodeDef* node_def = graph_def->add_node();
    *node_def = node->def();

    // Use the node's assigned device, if any, instead of the device requested
    // in the NodeDef.
    if (!node->assigned_device_name().empty()) {
      node_def->set_device(node->assigned_device_name());
    }

    // Get the inputs for this Node. We make sure control inputs are
    // after data inputs, as required by GraphDef.
    inputs.clear();
    inputs.resize(node->num_inputs(), nullptr);
    for (const Edge* edge : node->in_edges()) {
      if (edge->IsControlEdge()) {
        inputs.push_back(edge);
      } else {
        DCHECK(edge->dst_input() < inputs.size())
            << "Edge " << edge->DebugString()
            << " is overflowing the expected number of inputs ("
            << node->num_inputs() << ") for node " << node->DebugString();
        CHECK(inputs[edge->dst_input()] == nullptr)
            << "Edge " << edge->src()->name() << "->" << edge->dst()->name()
            << " conflicts with pre-existing input edge "
            << inputs[edge->dst_input()]->src()->name() << "->"
            << inputs[edge->dst_input()]->dst()->name();

        inputs[edge->dst_input()] = edge;
      }
    }
    // Sort the control inputs for more predictable serialization.
    std::sort(inputs.begin() + node->num_inputs(), inputs.end(),
              [](const Edge* a, const Edge* b) -> bool {
                return a->src()->name() < b->src()->name();
              });
    node_def->clear_input();
    node_def->mutable_input()->Reserve(inputs.size());

    for (size_t i = 0; i < inputs.size(); ++i) {
      const Edge* edge = inputs[i];
      if (edge == nullptr) {
        if (i < node->requested_inputs().size()) {
          node_def->add_input(node->requested_inputs()[i]);
        } else {
          node_def->add_input("");
        }
      } else {
        const Node* src = edge->src();
        if (!src->IsOp()) continue;
        AddInput(node_def, src->name(), edge->src_output());
      }
    }
  }
}

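// Returns 'prefix' with a per-graph counter appended, e.g. NewName("foo")
// might yield "foo/_7" depending on how many names were generated before it.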
std::string Graph::NewName(StringPiece prefix) {
  return strings::StrCat(prefix, "/_", name_counter_++);
}

Status Graph::IsValidNode(const Node* node) const {
  if (node == nullptr) {
    return errors::InvalidArgument("Node is null");
  }
  const int id = node->id();
  if (id < 0) {
    return errors::InvalidArgument("node id ", id, " is less than zero");
  }
  if (static_cast<size_t>(id) >= nodes_.size()) {
    return errors::InvalidArgument(
        "node id ", id, " is >= than number of nodes in graph ", nodes_.size());
  }
  if (nodes_[id] != node) {
    return errors::InvalidArgument("Node with id ", id,
                                   " is different from the passed in node. "
                                   "Does it belong to a different graph?");
  }
  return OkStatus();
}

Status Graph::IsValidOutputTensor(const Node* node, int idx) const {
  TF_RETURN_IF_ERROR(IsValidNode(node));
  if (idx >= node->num_outputs() || idx < 0) {
    return errors::OutOfRange("Node '", node->name(), "' (type: '",
                              node->op_def().name(),
                              "', num of outputs: ", node->num_outputs(),
                              ") does not have ", "output ", idx);
  }
  return OkStatus();
}

Status Graph::IsValidInputTensor(const Node* node, int idx) const {
  TF_RETURN_IF_ERROR(IsValidNode(node));
  if (idx >= node->num_inputs() || idx < 0) {
    return errors::OutOfRange("Node '", node->name(), "' (type: '",
                              node->op_def().name(),
                              "', num of inputs: ", node->num_inputs(),
                              ") does not have ", "input ", idx);
  }
  return OkStatus();
}

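// Constructs a Node in the arena (or reuses one from free_nodes_) and
// registers it in nodes_. The cost id is inherited from 'cost_node' when one
// is supplied (as in CopyNode); otherwise it defaults to the new node's id.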
Node* Graph::AllocateNode(std::shared_ptr<NodeProperties> props,
                          const Node* cost_node, Node::NodeClass node_class) {
  Node* node = nullptr;
  if (free_nodes_.empty()) {
    node = new (arena_.Alloc(sizeof(Node))) Node;  // placement new
  } else {
    node = free_nodes_.back();
    free_nodes_.pop_back();
  }
  node->graph_ = this;
  const int id = nodes_.size();
  int cost_id = cost_node ? cost_node->cost_id() : id;
  node->Initialize(id, cost_id, std::move(props), node_class);
  nodes_.push_back(node);
  ++num_nodes_;
  return node;
}

void Graph::ReleaseNode(Node* node) {
  TF_DCHECK_OK(IsValidNode(node)) << node->DebugString();
  nodes_[node->id()] = nullptr;
  free_nodes_.push_back(node);
  --num_nodes_;
  node->Clear();
}

// Ensures that 'device_name' is present in the device name table, and returns
// the index of that device name. The index is stable, and can be used in
// calls to Node::set_assigned_device_name_index().
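// For example, the empty name always interns to index 0, and repeated calls
// with the same non-empty name return the same positive index.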
int Graph::InternDeviceName(const std::string& device_name) {
  // Special case, very common. Also, this allows us to use a single map
  // lookup below, instead of two. The 'if (index_cell > 0)' test below
  // relies on this check.
  if (device_name.empty()) {
    return 0;
  }

  int& index_cell = device_names_map_[device_name];
  if (index_cell > 0) {
    return index_cell;
  }

  const int index = device_names_map_.size();
  index_cell = index;
  device_names_.push_back(device_name);
  return index;
}

Status Graph::AddWhileContext(StringPiece frame_name,
                              std::vector<Node*> enter_nodes,
                              std::vector<Node*> exit_nodes,
                              OutputTensor cond_output,
                              std::vector<OutputTensor> body_inputs,
                              std::vector<OutputTensor> body_outputs,
                              WhileContext** result) {
  auto pair = while_ctxs_.insert(std::pair<std::string, WhileContext>(
      std::string(frame_name),
      WhileContext(frame_name, std::move(enter_nodes), std::move(exit_nodes),
                   cond_output, std::move(body_inputs),
                   std::move(body_outputs))));
  if (!pair.second) {
    *result = nullptr;
    return errors::InvalidArgument("WhileContext with frame name '", frame_name,
                                   "' already exists");
  }
  *result = &pair.first->second;
  return OkStatus();
}

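// Builds a name -> Node* index over every node in the graph, including the
// _SOURCE and _SINK nodes. The returned map is a snapshot and is not kept up
// to date as the graph is later mutated.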
std::unordered_map<std::string, Node*> Graph::BuildNodeNameIndex() const {
  std::unordered_map<std::string, Node*> result;
  for (Node* n : nodes()) {
    result[n->name()] = n;
  }
  return result;
}

std::string Edge::DebugString() const {
  return strings::Printf("[id=%d %s:%d -> %s:%d]", id_, src_->name().c_str(),
                         src_output_, dst_->name().c_str(), dst_input_);
}

}  // namespace tensorflow