1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/graph/graph.h"
17
18 #include <vector>
19 #include "tensorflow/core/framework/graph.pb.h"
20 #include "tensorflow/core/framework/node_def.pb.h"
21 #include "tensorflow/core/framework/node_def_util.h"
22 #include "tensorflow/core/framework/op_kernel.h"
23 #include "tensorflow/core/framework/versions.pb.h"
24 #include "tensorflow/core/graph/while_context.h"
25 #include "tensorflow/core/lib/core/errors.h"
26 #include "tensorflow/core/lib/gtl/map_util.h"
27 #include "tensorflow/core/lib/hash/hash.h"
28 #include "tensorflow/core/lib/strings/strcat.h"
29 #include "tensorflow/core/lib/strings/stringprintf.h"
30 #include "tensorflow/core/platform/logging.h"
31 #include "tensorflow/core/public/version.h"
32
33 namespace tensorflow {
34
// Slot index that marks an edge endpoint as a control dependency rather than
// a data input/output. Graph::AddControlEdge passes it as both the source
// output slot and the destination input slot.
const int Graph::kControlSlot = -1;
36
37 struct NodeProperties {
38 public:
NodePropertiestensorflow::NodeProperties39 NodeProperties(const OpDef* op_def, const NodeDef& node_def,
40 const DataTypeSlice inputs, const DataTypeSlice outputs)
41 : op_def(op_def),
42 node_def(node_def),
43 input_types(inputs.begin(), inputs.end()),
44 output_types(outputs.begin(), outputs.end()) {}
45
46 const OpDef* op_def; // not owned
47 NodeDef node_def;
48 const DataTypeVector input_types;
49 const DataTypeVector output_types;
50 };
51
52 // Node
53
54 #define REF_CLASS(key, value) \
55 {key, value}, { "Ref" key, value }
56
57 const std::unordered_map<string, Node::NodeClass>& Node::kNodeClassTable =
58 *new std::unordered_map<string, Node::NodeClass>({
59 // Keep in same order as NodeClass values
60 REF_CLASS("Switch", NC_SWITCH),
61 REF_CLASS("Merge", NC_MERGE),
62 REF_CLASS("Enter", NC_ENTER),
63 REF_CLASS("Exit", NC_EXIT),
64 REF_CLASS("NextIteration", NC_NEXT_ITERATION),
65 {"LoopCond", NC_LOOP_COND},
66 {"ControlTrigger", NC_CONTROL_TRIGGER},
67 {"_Send", NC_SEND},
68 {"_HostSend", NC_HOST_SEND},
69 {"_Recv", NC_RECV},
70 {"_HostRecv", NC_HOST_RECV},
71 {"Const", NC_CONSTANT},
72 {"HostConst", NC_CONSTANT},
73 {"Variable", NC_VARIABLE},
74 {"VariableV2", NC_VARIABLE},
75 REF_CLASS("Identity", NC_IDENTITY),
76 {"GetSessionHandle", NC_GET_SESSION_HANDLE},
77 {"GetSessionHandleV2", NC_GET_SESSION_HANDLE},
78 {"GetSessionTensor", NC_GET_SESSION_TENSOR},
79 {"DeleteSessionTensor", NC_DELETE_SESSION_TENSOR},
80 {"Size", NC_METADATA},
81 {"Shape", NC_METADATA},
82 {"Rank", NC_METADATA},
83 {"_ScopedAllocator", NC_SCOPED_ALLOCATOR},
84 {"CollectiveReduce", NC_COLLECTIVE},
85 {"CollectiveBcastSend", NC_COLLECTIVE},
86 {"CollectiveBcastRecv", NC_COLLECTIVE},
87 {"FakeParam", NC_FAKE_PARAM},
88 {"PartitionedCall", NC_PARTITIONED_CALL},
89 {"StatefulPartitionedCall", NC_PARTITIONED_CALL},
90 // Not using the constants defined in FunctionLibraryDefinition for the
91 // 4 ops below because android inference library does not link
92 // tf.function related files.
93 {"_Arg", NC_ARG},
94 {"_DeviceArg", NC_ARG},
95 {"_Retval", NC_RETVAL},
96 {"_DeviceRetval", NC_RETVAL},
97 });
98
99 #undef REF_CLASS
100
GetNodeClassForOp(const string & ts)101 Node::NodeClass Node::GetNodeClassForOp(const string& ts) {
102 auto it = kNodeClassTable.find(ts);
103 if (it != kNodeClassTable.end()) {
104 return it->second;
105 } else {
106 return NC_OTHER;
107 }
108 }
109
DebugString() const110 string Node::DebugString() const {
111 string ret = strings::StrCat("{name:'", name(), "' id:", id_);
112 if (IsSource()) {
113 strings::StrAppend(&ret, " source}");
114 } else if (IsSink()) {
115 strings::StrAppend(&ret, " sink}");
116 } else {
117 strings::StrAppend(&ret, " op device:");
118 strings::StrAppend(&ret, "{", assigned_device_name(), "}");
119 strings::StrAppend(&ret, " def:{", SummarizeNode(*this), "}}");
120 }
121 return ret;
122 }
123
// Constructs a node in the "uninitialized" state: ids of -1, class
// NC_UNINITIALIZED, no properties, device-name index 0 and no while context.
// Node::Initialize fills in the ids, properties and class afterwards.
Node::Node()
    : id_(-1),
      cost_id_(-1),
      class_(NC_UNINITIALIZED),
      props_(nullptr),
      assigned_device_name_index_(0),
      while_ctx_(nullptr) {}
131
Initialize(int id,int cost_id,std::shared_ptr<NodeProperties> props)132 void Node::Initialize(int id, int cost_id,
133 std::shared_ptr<NodeProperties> props) {
134 DCHECK_EQ(id_, -1);
135 DCHECK(in_edges_.empty());
136 DCHECK(out_edges_.empty());
137 id_ = id;
138 cost_id_ = cost_id;
139
140 props_ = std::move(props);
141 // Initialize the class_ based on the type string
142 class_ = GetNodeClassForOp(props_->node_def.op());
143 }
144
Clear()145 void Node::Clear() {
146 in_edges_.clear();
147 out_edges_.clear();
148 id_ = -1;
149 cost_id_ = -1;
150 class_ = NC_UNINITIALIZED;
151 props_.reset();
152 assigned_device_name_index_ = 0;
153 }
154
UpdateProperties()155 void Node::UpdateProperties() {
156 DataTypeVector inputs;
157 DataTypeVector outputs;
158 Status status =
159 InOutTypesForNode(props_->node_def, *(props_->op_def), &inputs, &outputs);
160 if (!status.ok()) {
161 LOG(ERROR) << "Failed at updating node: " << status;
162 return;
163 }
164 props_ = std::make_shared<NodeProperties>(props_->op_def, props_->node_def,
165 inputs, outputs);
166 }
167
name() const168 const string& Node::name() const { return props_->node_def.name(); }
type_string() const169 const string& Node::type_string() const { return props_->node_def.op(); }
def() const170 const NodeDef& Node::def() const { return props_->node_def; }
op_def() const171 const OpDef& Node::op_def() const { return *props_->op_def; }
172
num_inputs() const173 int32 Node::num_inputs() const { return props_->input_types.size(); }
input_type(int32 i) const174 DataType Node::input_type(int32 i) const { return props_->input_types[i]; }
input_types() const175 const DataTypeVector& Node::input_types() const { return props_->input_types; }
176
num_outputs() const177 int32 Node::num_outputs() const { return props_->output_types.size(); }
output_type(int32 o) const178 DataType Node::output_type(int32 o) const { return props_->output_types[o]; }
output_types() const179 const DataTypeVector& Node::output_types() const {
180 return props_->output_types;
181 }
182
attrs() const183 AttrSlice Node::attrs() const { return AttrSlice(def()); }
184
requested_inputs() const185 const protobuf::RepeatedPtrField<string>& Node::requested_inputs() const {
186 return def().input();
187 }
188
requested_device() const189 const string& Node::requested_device() const { return def().device(); }
190
out_nodes() const191 gtl::iterator_range<NeighborIter> Node::out_nodes() const {
192 return gtl::make_range(NeighborIter(out_edges_.begin(), false),
193 NeighborIter(out_edges_.end(), false));
194 }
195
in_nodes() const196 gtl::iterator_range<NeighborIter> Node::in_nodes() const {
197 return gtl::make_range(NeighborIter(in_edges_.begin(), true),
198 NeighborIter(in_edges_.end(), true));
199 }
200
MaybeCopyOnWrite()201 void Node::MaybeCopyOnWrite() {
202 // NodeProperties may be shared between Nodes. Make a copy if so.
203 if (!props_.unique()) {
204 props_ = std::make_shared<NodeProperties>(*props_);
205 }
206 }
207
AddAttrHelper(const string & name)208 AttrValue* Node::AddAttrHelper(const string& name) {
209 MaybeCopyOnWrite();
210 return &((*props_->node_def.mutable_attr())[name]);
211 }
212
ClearAttr(const string & name)213 void Node::ClearAttr(const string& name) {
214 MaybeCopyOnWrite();
215 (*props_->node_def.mutable_attr()).erase(name);
216 }
217
set_name(string name)218 void Node::set_name(string name) {
219 MaybeCopyOnWrite();
220 props_->node_def.set_name(std::move(name));
221 }
222
set_requested_device(const string & device)223 void Node::set_requested_device(const string& device) {
224 MaybeCopyOnWrite();
225 props_->node_def.set_device(device);
226 }
227
set_original_node_names(const std::vector<string> & names)228 void Node::set_original_node_names(const std::vector<string>& names) {
229 MaybeCopyOnWrite();
230 props_->node_def.mutable_experimental_debug_info()
231 ->clear_original_node_names();
232 if (!names.empty()) {
233 *props_->node_def.mutable_experimental_debug_info()
234 ->mutable_original_node_names() = {names.begin(), names.end()};
235 }
236 }
237
input_edge(int idx,const Edge ** e) const238 Status Node::input_edge(int idx, const Edge** e) const {
239 if (idx < 0 || idx >= num_inputs()) {
240 return errors::InvalidArgument("Invalid input_edge index: ", idx, ", Node ",
241 name(), " only has ", num_inputs(),
242 " inputs.");
243 }
244
245 // This does a linear search over the edges. In the common case,
246 // the number of elements is small enough that this search isn't
247 // expensive. Should it become a bottleneck, one can make an
248 // optimization where, if the number of edges is small, we use
249 // linear iteration, and if the number of edges is large, we perform
250 // an indexing step during construction that keeps an array of Edges
251 // indexed by pointer. This would keep the size of each Node small
252 // in the common case but make this function faster when the number
253 // of edges is large.
254 for (const Edge* edge : in_edges()) {
255 if (edge->dst_input() == idx) {
256 *e = edge;
257 return Status::OK();
258 }
259 }
260
261 return errors::NotFound("Could not find input edge ", idx, " for ", name());
262 }
263
264 // Returns a vector of the non-control input edges to a node, indexed by ID.
input_edges(std::vector<const Edge * > * input_edges) const265 Status Node::input_edges(std::vector<const Edge*>* input_edges) const {
266 input_edges->clear();
267 input_edges->resize(num_inputs(), nullptr);
268
269 for (const Edge* edge : in_edges()) {
270 if (edge->IsControlEdge()) continue;
271 if (edge->dst_input() < 0 || edge->dst_input() >= num_inputs()) {
272 return errors::Internal("Invalid edge input number ", edge->dst_input());
273 }
274 if ((*input_edges)[edge->dst_input()] != nullptr) {
275 return errors::Internal("Duplicate edge input number: ",
276 edge->dst_input());
277 }
278 (*input_edges)[edge->dst_input()] = edge;
279 }
280
281 for (int i = 0; i < num_inputs(); ++i) {
282 if ((*input_edges)[i] == nullptr) {
283 return errors::InvalidArgument("Missing edge input number: ", i);
284 }
285 }
286 return Status::OK();
287 }
288
input_node(int idx,Node ** n) const289 Status Node::input_node(int idx, Node** n) const {
290 const Edge* e;
291 TF_RETURN_IF_ERROR(input_edge(idx, &e));
292 if (e == nullptr) {
293 *n = nullptr;
294 } else {
295 *n = e->src();
296 }
297 return Status::OK();
298 }
299
input_node(int idx,const Node ** const_n) const300 Status Node::input_node(int idx, const Node** const_n) const {
301 Node* n;
302 TF_RETURN_IF_ERROR(input_node(idx, &n));
303 *const_n = n;
304 return Status::OK();
305 }
306
input_tensor(int idx,OutputTensor * t) const307 Status Node::input_tensor(int idx, OutputTensor* t) const {
308 const Edge* e;
309 TF_RETURN_IF_ERROR(input_edge(idx, &e));
310 DCHECK(e != nullptr);
311 *t = OutputTensor(e->src(), e->src_output());
312 return Status::OK();
313 }
314
315 // NodeDebugInfo
316
NodeDebugInfo(const Node & n)317 NodeDebugInfo::NodeDebugInfo(const Node& n) : NodeDebugInfo(n.def()) {}
NodeDebugInfo(const NodeDef & ndef)318 NodeDebugInfo::NodeDebugInfo(const NodeDef& ndef) : name(ndef.name()) {
319 if (ndef.has_experimental_debug_info()) {
320 const auto& names = ndef.experimental_debug_info().original_node_names();
321 original_node_names.assign(names.begin(), names.end());
322 }
323 }
324
325 // InputTensor
326
operator ==(const InputTensor & other) const327 bool InputTensor::operator==(const InputTensor& other) const {
328 return node == other.node && index == other.index;
329 }
330
operator ()(InputTensor const & s) const331 uint64 InputTensor::Hash::operator()(InputTensor const& s) const {
332 return Hash64Combine(std::hash<const Node*>()(s.node),
333 std::hash<int>()(s.index));
334 }
335
336 // OutputTensor
337
operator ==(const OutputTensor & other) const338 bool OutputTensor::operator==(const OutputTensor& other) const {
339 return node == other.node && index == other.index;
340 }
341
operator ()(OutputTensor const & s) const342 uint64 OutputTensor::Hash::operator()(OutputTensor const& s) const {
343 return Hash64Combine(std::hash<const Node*>()(s.node),
344 std::hash<int>()(s.index));
345 }
346
347 // Graph
348
Graph(const OpRegistryInterface * ops)349 Graph::Graph(const OpRegistryInterface* ops)
350 : ops_(ops, FunctionDefLibrary()),
351 versions_(new VersionDef),
352 arena_(8 << 10 /* 8kB */) {
353 versions_->set_producer(TF_GRAPH_DEF_VERSION);
354 versions_->set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER);
355
356 // Initialize the name interning table for assigned_device_name.
357 device_names_.push_back("");
358 DCHECK_EQ(0, InternDeviceName(""));
359
360 // Source and sink have no endpoints, just control edges.
361 NodeDef def;
362 def.set_name("_SOURCE");
363 def.set_op("NoOp");
364 Status status;
365 Node* source = AddNode(def, &status);
366 TF_CHECK_OK(status);
367 CHECK_EQ(source->id(), kSourceId);
368
369 def.set_name("_SINK");
370 Node* sink = AddNode(def, &status);
371 TF_CHECK_OK(status);
372 CHECK_EQ(sink->id(), kSinkId);
373
374 AddControlEdge(source, sink);
375 }
376
Graph(const FunctionLibraryDefinition & flib_def)377 Graph::Graph(const FunctionLibraryDefinition& flib_def)
378 : Graph(flib_def.default_registry()) {
379 // Need a new-enough consumer to support the functions we add to the graph.
380 if (flib_def.ToProto().function_size() > 0 &&
381 versions_->min_consumer() < 12) {
382 versions_->set_min_consumer(12);
383 }
384 Status s = ops_.AddLibrary(flib_def);
385 CHECK(s.ok()) << s.error_message();
386 }
387
~Graph()388 Graph::~Graph() {
389 // Manually call the destructors for all the Nodes we constructed using
390 // placement new.
391 for (Node* node : nodes_) {
392 if (node != nullptr) {
393 node->~Node();
394 }
395 }
396 for (Node* node : free_nodes_) {
397 node->~Node();
398 }
399 // Edges have no destructor, and we arena-allocated them, so no need to
400 // destroy them.
401 }
402
versions() const403 const VersionDef& Graph::versions() const { return *versions_; }
set_versions(const VersionDef & versions)404 void Graph::set_versions(const VersionDef& versions) { *versions_ = versions; }
405
AddNode(const NodeDef & node_def,Status * status)406 Node* Graph::AddNode(const NodeDef& node_def, Status* status) {
407 const OpDef* op_def;
408 status->Update(ops_.LookUpOpDef(node_def.op(), &op_def));
409 if (!status->ok()) return nullptr;
410
411 DataTypeVector inputs;
412 DataTypeVector outputs;
413 status->Update(InOutTypesForNode(node_def, *op_def, &inputs, &outputs));
414 if (!status->ok()) {
415 *status = AttachDef(*status, node_def);
416 return nullptr;
417 }
418
419 Node* node = AllocateNode(
420 std::make_shared<NodeProperties>(op_def, node_def, inputs, outputs),
421 nullptr);
422 return node;
423 }
424
// Adds to this graph a copy of `node` (which must be neither source nor
// sink). The copy initially shares `node`'s properties, takes its cost id
// from `node`, and gets the same assigned device name.
Node* Graph::CopyNode(const Node* node) {
  DCHECK(!node->IsSource());
  DCHECK(!node->IsSink());
  Node* copy = AllocateNode(node->props_, node);
  copy->set_assigned_device_name(node->assigned_device_name());

  // The OpDef of a function may be owned by the Graph that owns 'node', so
  // look the OpDef up again in this graph. Only when it differs does the copy
  // need its own properties pointing at the local OpDef.
  const OpDef* local_op_def;
  TF_CHECK_OK(ops_.LookUpOpDef(node->type_string(), &local_op_def));
  if (local_op_def == node->props_->op_def) {
    return copy;
  }
  copy->MaybeCopyOnWrite();
  copy->props_->op_def = local_op_def;
  return copy;
}
443
RemoveNode(Node * node)444 void Graph::RemoveNode(Node* node) {
445 TF_DCHECK_OK(IsValidNode(node)) << node->DebugString();
446 DCHECK(!node->IsSource());
447 DCHECK(!node->IsSink());
448
449 // Remove any edges involving this node.
450 while (!node->in_edges_.empty()) {
451 RemoveEdge(*node->in_edges_.begin());
452 }
453 while (!node->out_edges_.empty()) {
454 RemoveEdge(*node->out_edges_.begin());
455 }
456 ReleaseNode(node);
457 }
458
AddEdge(Node * source,int x,Node * dest,int y)459 const Edge* Graph::AddEdge(Node* source, int x, Node* dest, int y) {
460 TF_DCHECK_OK(IsValidNode(source)) << source->DebugString();
461 TF_DCHECK_OK(IsValidNode(dest)) << dest->DebugString();
462
463 // source/sink must only be linked via control slots, and
464 // control slots must only be linked to control slots.
465 if (source == source_node() || dest == sink_node() || x == kControlSlot ||
466 y == kControlSlot) {
467 DCHECK_EQ(x, kControlSlot) << source->DebugString();
468 DCHECK_EQ(y, kControlSlot) << dest->DebugString();
469 }
470
471 Edge* e = nullptr;
472 if (free_edges_.empty()) {
473 e = new (arena_.Alloc(sizeof(Edge))) Edge; // placement new
474 } else {
475 e = free_edges_.back();
476 free_edges_.pop_back();
477 }
478 e->id_ = edges_.size();
479 e->src_ = source;
480 e->dst_ = dest;
481 e->src_output_ = x;
482 e->dst_input_ = y;
483 CHECK(source->out_edges_.insert(e).second);
484 CHECK(dest->in_edges_.insert(e).second);
485 edges_.push_back(e);
486 ++num_edges_;
487 return e;
488 }
489
RemoveEdge(const Edge * e)490 void Graph::RemoveEdge(const Edge* e) {
491 TF_DCHECK_OK(IsValidNode(e->src_)) << e->src_->DebugString();
492 TF_DCHECK_OK(IsValidNode(e->dst_)) << e->dst_->DebugString();
493 CHECK_EQ(e->src_->out_edges_.erase(e), size_t{1});
494 CHECK_EQ(e->dst_->in_edges_.erase(e), size_t{1});
495 CHECK_EQ(e, edges_[e->id_]);
496 CHECK_GT(num_edges_, 0);
497
498 edges_[e->id_] = nullptr;
499
500 Edge* del = const_cast<Edge*>(e);
501 del->src_ = nullptr;
502 del->dst_ = nullptr;
503 del->id_ = -1;
504 del->src_output_ = kControlSlot - 1;
505 del->dst_input_ = kControlSlot - 1;
506 free_edges_.push_back(del);
507 --num_edges_;
508 }
509
AddControlEdge(Node * source,Node * dest,bool allow_duplicates)510 const Edge* Graph::AddControlEdge(Node* source, Node* dest,
511 bool allow_duplicates) {
512 if (!allow_duplicates) {
513 for (const Edge* edge : dest->in_edges()) {
514 if (edge->IsControlEdge() && edge->src() == source) {
515 // The requested edge already exists.
516 return nullptr;
517 }
518 }
519 }
520 // Modify dest's NodeDef if necessary.
521 if (!source->IsSource() && !dest->IsSink() && !allow_duplicates) {
522 // Check if this input is already in dest's NodeDef.
523 const string new_input = strings::StrCat("^", source->name());
524 bool input_exists = false;
525 for (const string& input : dest->props_->node_def.input()) {
526 if (input == new_input) {
527 input_exists = true;
528 break;
529 }
530 }
531 if (!input_exists) {
532 dest->MaybeCopyOnWrite();
533 dest->props_->node_def.add_input(new_input);
534 }
535 }
536 return AddEdge(source, kControlSlot, dest, kControlSlot);
537 }
538
RemoveControlEdge(const Edge * e)539 void Graph::RemoveControlEdge(const Edge* e) {
540 if (!e->src_->IsSource() && !e->dst_->IsSink()) {
541 e->dst_->MaybeCopyOnWrite();
542 string e_src_name = strings::StrCat("^", e->src_->name());
543 auto* inputs = e->dst_->props_->node_def.mutable_input();
544 for (auto it = inputs->begin(); it != inputs->end(); ++it) {
545 if (*it == e_src_name) {
546 inputs->erase(it);
547 break;
548 }
549 }
550 }
551 RemoveEdge(e);
552 }
553
554 namespace {
FindEdge(const Node * dst,int index)555 const Edge* FindEdge(const Node* dst, int index) {
556 for (const Edge* e : dst->in_edges()) {
557 if (e->dst_input() == index) return e;
558 }
559 return nullptr;
560 }
561 } // namespace
562
UpdateEdge(Node * new_src,int new_src_index,Node * dst,int dst_index)563 Status Graph::UpdateEdge(Node* new_src, int new_src_index, Node* dst,
564 int dst_index) {
565 TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index));
566 TF_RETURN_IF_ERROR(IsValidInputTensor(dst, dst_index));
567 const Edge* e = FindEdge(dst, dst_index);
568 if (e == nullptr) {
569 return errors::InvalidArgument("Couldn't find edge to ",
570 FormatNodeForError(*dst));
571 }
572 RemoveEdge(e);
573 AddEdge(new_src, new_src_index, dst, dst_index);
574 dst->MaybeCopyOnWrite();
575 (*dst->props_->node_def.mutable_input())[dst_index] =
576 strings::StrCat(new_src->name(), ":", new_src_index);
577 return Status::OK();
578 }
579
AddWhileInputHack(Node * new_src,int new_src_index,Node * dst)580 Status Graph::AddWhileInputHack(Node* new_src, int new_src_index, Node* dst) {
581 if (dst->type_string() != "While") {
582 return errors::Internal(
583 "dst argument to AddWhileEdgeHack should be a While op, got: ",
584 dst->DebugString());
585 }
586 TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index));
587 // Find the current number of data inputs. We'll add the new edge to the next
588 // missing data input.
589 int dst_index = 0;
590 for (const Edge* edge : dst->in_edges()) {
591 if (edge->IsControlEdge()) continue;
592 ++dst_index;
593 }
594 TF_RETURN_IF_ERROR(IsValidInputTensor(dst, dst_index));
595 AddEdge(new_src, new_src_index, dst, dst_index);
596 dst->MaybeCopyOnWrite();
597 dst->props_->node_def.add_input(
598 strings::StrCat(new_src->name(), ":", new_src_index));
599 return Status::OK();
600 }
601
AddFunctionLibrary(const FunctionDefLibrary & fdef_lib)602 Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) {
603 // Need a new-enough consumer to support the functions we add to the graph.
604 if (fdef_lib.function_size() > 0 && versions_->min_consumer() < 12) {
605 versions_->set_min_consumer(12);
606 }
607 return ops_.AddLibrary(fdef_lib);
608 }
609
610 namespace {
611
AddInput(NodeDef * dst,StringPiece src_name,int src_slot)612 void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) {
613 if (src_slot == Graph::kControlSlot) {
614 dst->add_input(strings::StrCat("^", src_name));
615 } else if (src_slot == 0) {
616 dst->add_input(src_name.data(), src_name.size());
617 } else {
618 dst->add_input(strings::StrCat(src_name, ":", src_slot));
619 }
620 }
621
622 } // namespace
623
ToGraphDef(GraphDef * graph_def) const624 void Graph::ToGraphDef(GraphDef* graph_def) const {
625 ToGraphDefSubRange(graph_def, 0);
626 }
627
ToGraphDefDebug() const628 GraphDef Graph::ToGraphDefDebug() const {
629 GraphDef ret;
630 ToGraphDef(&ret);
631 return ret;
632 }
633
ToGraphDefSubRange(GraphDef * graph_def,int from_node_id) const634 void Graph::ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const {
635 graph_def->Clear();
636 *graph_def->mutable_versions() = versions();
637 *graph_def->mutable_library() = ops_.ToProto();
638
639 graph_def->mutable_node()->Reserve(std::max(1, num_nodes() - from_node_id));
640
641 std::vector<const Edge*>
642 inputs; // Construct this outside the loop for speed.
643 for (auto id = from_node_id; id < num_node_ids(); ++id) {
644 const Node* node = FindNodeId(id);
645 if (node == nullptr || !node->IsOp()) continue;
646 NodeDef* node_def = graph_def->add_node();
647 *node_def = node->def();
648
649 // Use the node's assigned device, if any, instead of the device requested
650 // in the NodeDef.
651 if (!node->assigned_device_name().empty()) {
652 node_def->set_device(node->assigned_device_name());
653 }
654
655 // Get the inputs for this Node. We make sure control inputs are
656 // after data inputs, as required by GraphDef.
657 inputs.clear();
658 inputs.resize(node->num_inputs(), nullptr);
659 for (const Edge* edge : node->in_edges()) {
660 if (edge->IsControlEdge()) {
661 inputs.push_back(edge);
662 } else {
663 CHECK(inputs[edge->dst_input()] == nullptr)
664 << "Edge " << edge->src()->DebugString() << ":"
665 << edge->dst()->DebugString() << " with dst_input "
666 << edge->dst_input() << " and had pre-existing input edge "
667 << inputs[edge->dst_input()]->src()->DebugString() << ":"
668 << inputs[edge->dst_input()]->dst()->DebugString();
669
670 inputs[edge->dst_input()] = edge;
671 }
672 }
673 // Sort the control inputs for more predictable serialization.
674 std::sort(inputs.begin() + node->num_inputs(), inputs.end(),
675 [](const Edge* a, const Edge* b) -> bool {
676 return a->src()->name() < b->src()->name();
677 });
678 node_def->clear_input();
679 node_def->mutable_input()->Reserve(inputs.size());
680
681 for (size_t i = 0; i < inputs.size(); ++i) {
682 const Edge* edge = inputs[i];
683 if (edge == nullptr) {
684 if (i < node->requested_inputs().size()) {
685 node_def->add_input(node->requested_inputs()[i]);
686 } else {
687 node_def->add_input("");
688 }
689 } else {
690 const Node* src = edge->src();
691 if (!src->IsOp()) continue;
692 AddInput(node_def, src->name(), edge->src_output());
693 }
694 }
695 }
696 }
697
NewName(StringPiece prefix)698 string Graph::NewName(StringPiece prefix) {
699 return strings::StrCat(prefix, "/_", name_counter_++);
700 }
701
IsValidNode(const Node * node) const702 Status Graph::IsValidNode(const Node* node) const {
703 if (node == nullptr) {
704 return errors::InvalidArgument("Node is null");
705 }
706 const int id = node->id();
707 if (id < 0) {
708 return errors::InvalidArgument("node id ", id, " is less than zero");
709 }
710 if (static_cast<size_t>(id) >= nodes_.size()) {
711 return errors::InvalidArgument(
712 "node id ", id, " is >= than number of nodes in graph ", nodes_.size());
713 }
714 if (nodes_[id] != node) {
715 return errors::InvalidArgument("Node with id ", id,
716 " is different from the passed in node. "
717 "Does it belong to a different graph?");
718 }
719 return Status::OK();
720 }
721
IsValidOutputTensor(const Node * node,int idx) const722 Status Graph::IsValidOutputTensor(const Node* node, int idx) const {
723 TF_RETURN_IF_ERROR(IsValidNode(node));
724 if (idx >= node->num_outputs() || idx < 0) {
725 return errors::OutOfRange("Node '", node->name(), "' (type: '",
726 node->op_def().name(),
727 "', num of outputs: ", node->num_outputs(),
728 ") does not have ", "output ", idx);
729 }
730 return Status::OK();
731 }
732
IsValidInputTensor(const Node * node,int idx) const733 Status Graph::IsValidInputTensor(const Node* node, int idx) const {
734 TF_RETURN_IF_ERROR(IsValidNode(node));
735 if (idx >= node->num_inputs() || idx < 0) {
736 return errors::OutOfRange("Node '", node->name(), "' (type: '",
737 node->op_def().name(),
738 "', num of inputs: ", node->num_inputs(),
739 ") does not have ", "input ", idx);
740 }
741 return Status::OK();
742 }
743
// Obtains a Node object (recycled from the free list when possible,
// otherwise placement-new'ed in the arena), initializes it with `props` and
// appends it to the node table. The node id is its position in the table;
// the cost id is taken from `cost_node` when given, else it equals the id.
Node* Graph::AllocateNode(std::shared_ptr<NodeProperties> props,
                          const Node* cost_node) {
  Node* node = nullptr;
  if (!free_nodes_.empty()) {
    node = free_nodes_.back();
    free_nodes_.pop_back();
  } else {
    node = new (arena_.Alloc(sizeof(Node))) Node;  // placement new
  }
  node->graph_ = this;

  const int new_id = nodes_.size();
  const int cost_id = (cost_node != nullptr) ? cost_node->cost_id() : new_id;
  node->Initialize(new_id, cost_id, std::move(props));

  nodes_.push_back(node);
  ++num_nodes_;
  return node;
}
761
ReleaseNode(Node * node)762 void Graph::ReleaseNode(Node* node) {
763 TF_DCHECK_OK(IsValidNode(node)) << node->DebugString();
764 nodes_[node->id()] = nullptr;
765 free_nodes_.push_back(node);
766 --num_nodes_;
767 node->Clear();
768 }
769
770 // Ensures that 'device_name' is present in the device name table, and returns
771 // the index of that device name. The index is stable, and can be used in
772 // calls to Node::set_assigned_device_name_index().
InternDeviceName(const string & device_name)773 int Graph::InternDeviceName(const string& device_name) {
774 // Special case, very common. Also, this allows us to use a single map
775 // lookup below, instead of two. The 'if (index_cell > 0)' test below
776 // relies on this check.
777 if (device_name.empty()) {
778 return 0;
779 }
780
781 int& index_cell = device_names_map_[device_name];
782 if (index_cell > 0) {
783 return index_cell;
784 }
785
786 const int index = device_names_map_.size();
787 index_cell = index;
788 device_names_.push_back(device_name);
789 return index;
790 }
791
AddWhileContext(StringPiece frame_name,std::vector<Node * > enter_nodes,std::vector<Node * > exit_nodes,OutputTensor cond_output,std::vector<OutputTensor> body_inputs,std::vector<OutputTensor> body_outputs,WhileContext ** result)792 Status Graph::AddWhileContext(StringPiece frame_name,
793 std::vector<Node*> enter_nodes,
794 std::vector<Node*> exit_nodes,
795 OutputTensor cond_output,
796 std::vector<OutputTensor> body_inputs,
797 std::vector<OutputTensor> body_outputs,
798 WhileContext** result) {
799 auto pair = while_ctxs_.insert(std::pair<string, WhileContext>(
800 string(frame_name),
801 WhileContext(frame_name, std::move(enter_nodes), std::move(exit_nodes),
802 cond_output, std::move(body_inputs),
803 std::move(body_outputs))));
804 if (!pair.second) {
805 *result = nullptr;
806 return errors::InvalidArgument("WhileContext with frame name '", frame_name,
807 "' already exists");
808 }
809 *result = &pair.first->second;
810 return Status::OK();
811 }
812
BuildNodeNameIndex() const813 std::unordered_map<string, Node*> Graph::BuildNodeNameIndex() const {
814 std::unordered_map<string, Node*> result;
815 for (Node* n : nodes()) {
816 result[n->name()] = n;
817 }
818 return result;
819 }
820
DebugString() const821 string Edge::DebugString() const {
822 return strings::Printf("[id=%d %s:%d -> %s:%d]", id_, src_->name().c_str(),
823 src_output_, dst_->name().c_str(), dst_input_);
824 }
825
826 } // namespace tensorflow
827