1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/graph/graph.h"
17
18 #include <vector>
19 #include "tensorflow/core/framework/graph.pb.h"
20 #include "tensorflow/core/framework/node_def.pb.h"
21 #include "tensorflow/core/framework/node_def_util.h"
22 #include "tensorflow/core/framework/op_kernel.h"
23 #include "tensorflow/core/framework/versions.pb.h"
24 #include "tensorflow/core/graph/while_context.h"
25 #include "tensorflow/core/lib/core/errors.h"
26 #include "tensorflow/core/lib/gtl/map_util.h"
27 #include "tensorflow/core/lib/strings/strcat.h"
28 #include "tensorflow/core/lib/strings/stringprintf.h"
29 #include "tensorflow/core/platform/logging.h"
30 #include "tensorflow/core/public/version.h"
31
32 namespace tensorflow {
33
// Slot index that marks a control dependency. Graph::AddControlEdge() uses
// this value as both the source output and the destination input of the edge
// it creates.
const int Graph::kControlSlot = -1;
35
// Bundle of per-node data that a Node reaches through `props_`: the op
// definition, the NodeDef, and the input/output type lists. An instance may
// be shared between several Nodes (see Node::MaybeCopyOnWrite()).
class NodeProperties {
 public:
  // Stores the `op_def` pointer as-is, copies `node_def`, and copies the
  // elements of `inputs` / `outputs` into the type vectors.
  NodeProperties(const OpDef* op_def, const NodeDef& node_def,
                 const DataTypeSlice inputs, const DataTypeSlice outputs)
      : op_def(op_def),
        node_def(node_def),
        input_types(inputs.begin(), inputs.end()),
        output_types(outputs.begin(), outputs.end()) {}

  const OpDef* op_def;  // not owned
  NodeDef node_def;     // private copy; mutated by Node/Graph setters
  const DataTypeVector input_types;   // one entry per data input
  const DataTypeVector output_types;  // one entry per data output
};
50
51 // Node
52
// Expands to two table entries that map both `key` and its "Ref"-prefixed
// variant (e.g. "Switch" and "RefSwitch") to the same NodeClass `value`.
#define REF_CLASS(key, value) \
  {key, value}, { "Ref" key, value }

// Lookup table from op type string to NodeClass, consulted by
// GetNodeClassForOp(). The map is allocated with `new` and never deleted;
// the reference keeps it reachable for the rest of the process
// (presumably to sidestep static destruction order issues).
const std::unordered_map<string, Node::NodeClass>& Node::kNodeClassTable =
    *new std::unordered_map<string, Node::NodeClass>({
        // Keep in same order as NodeClass values
        REF_CLASS("Switch", NC_SWITCH),
        REF_CLASS("Merge", NC_MERGE),
        REF_CLASS("Enter", NC_ENTER),
        REF_CLASS("Exit", NC_EXIT),
        REF_CLASS("NextIteration", NC_NEXT_ITERATION),
        {"LoopCond", NC_LOOP_COND},
        {"ControlTrigger", NC_CONTROL_TRIGGER},
        {"_Send", NC_SEND},
        {"_HostSend", NC_HOST_SEND},
        {"_Recv", NC_RECV},
        {"_HostRecv", NC_HOST_RECV},
        {"Const", NC_CONSTANT},
        {"HostConst", NC_CONSTANT},
        {"Variable", NC_VARIABLE},
        {"VariableV2", NC_VARIABLE},
        REF_CLASS("Identity", NC_IDENTITY),
        {"GetSessionHandle", NC_GET_SESSION_HANDLE},
        {"GetSessionHandleV2", NC_GET_SESSION_HANDLE},
        {"GetSessionTensor", NC_GET_SESSION_TENSOR},
        {"DeleteSessionTensor", NC_DELETE_SESSION_TENSOR},
        {"Size", NC_METADATA},
        {"Shape", NC_METADATA},
        {"Rank", NC_METADATA},
    });

// The macro is only needed to build the table above.
#undef REF_CLASS
85
GetNodeClassForOp(const string & ts)86 Node::NodeClass Node::GetNodeClassForOp(const string& ts) {
87 auto it = kNodeClassTable.find(ts);
88 if (it != kNodeClassTable.end()) {
89 return it->second;
90 } else {
91 return NC_OTHER;
92 }
93 }
94
DebugString() const95 string Node::DebugString() const {
96 string ret = strings::StrCat("{name:'", name(), "' id:", id_);
97 if (IsSource()) {
98 strings::StrAppend(&ret, " source}");
99 } else if (IsSink()) {
100 strings::StrAppend(&ret, " sink}");
101 } else {
102 strings::StrAppend(&ret, " op device:");
103 strings::StrAppend(&ret, "{", assigned_device_name(), "}");
104 strings::StrAppend(&ret, " def:{", SummarizeNode(*this), "}}");
105 }
106 return ret;
107 }
108
Node()109 Node::Node()
110 : id_(-1),
111 cost_id_(-1),
112 class_(NC_UNINITIALIZED),
113 props_(nullptr),
114 assigned_device_name_index_(0),
115 while_ctx_(nullptr) {}
116
Initialize(int id,int cost_id,std::shared_ptr<NodeProperties> props)117 void Node::Initialize(int id, int cost_id,
118 std::shared_ptr<NodeProperties> props) {
119 DCHECK_EQ(id_, -1);
120 DCHECK(in_edges_.empty());
121 DCHECK(out_edges_.empty());
122 id_ = id;
123 cost_id_ = cost_id;
124
125 props_ = std::move(props);
126 // Initialize the class_ based on the type string
127 class_ = GetNodeClassForOp(props_->node_def.op());
128 }
129
Clear()130 void Node::Clear() {
131 in_edges_.clear();
132 out_edges_.clear();
133 id_ = -1;
134 cost_id_ = -1;
135 class_ = NC_UNINITIALIZED;
136 props_.reset();
137 assigned_device_name_index_ = 0;
138 }
139
name() const140 const string& Node::name() const { return props_->node_def.name(); }
type_string() const141 const string& Node::type_string() const { return props_->node_def.op(); }
def() const142 const NodeDef& Node::def() const { return props_->node_def; }
op_def() const143 const OpDef& Node::op_def() const { return *props_->op_def; }
144
num_inputs() const145 int32 Node::num_inputs() const { return props_->input_types.size(); }
input_type(int32 i) const146 DataType Node::input_type(int32 i) const { return props_->input_types[i]; }
input_types() const147 const DataTypeVector& Node::input_types() const { return props_->input_types; }
148
num_outputs() const149 int32 Node::num_outputs() const { return props_->output_types.size(); }
output_type(int32 o) const150 DataType Node::output_type(int32 o) const { return props_->output_types[o]; }
output_types() const151 const DataTypeVector& Node::output_types() const {
152 return props_->output_types;
153 }
154
attrs() const155 AttrSlice Node::attrs() const { return AttrSlice(def()); }
156
requested_inputs() const157 const protobuf::RepeatedPtrField<string>& Node::requested_inputs() const {
158 return def().input();
159 }
160
requested_device() const161 const string& Node::requested_device() const { return def().device(); }
162
out_nodes() const163 gtl::iterator_range<NeighborIter> Node::out_nodes() const {
164 return gtl::make_range(NeighborIter(out_edges_.begin(), false),
165 NeighborIter(out_edges_.end(), false));
166 }
167
in_nodes() const168 gtl::iterator_range<NeighborIter> Node::in_nodes() const {
169 return gtl::make_range(NeighborIter(in_edges_.begin(), true),
170 NeighborIter(in_edges_.end(), true));
171 }
172
MaybeCopyOnWrite()173 void Node::MaybeCopyOnWrite() {
174 // NodeProperties may be shared between Nodes. Make a copy if so.
175 if (!props_.unique()) {
176 props_ = std::make_shared<NodeProperties>(*props_);
177 }
178 }
179
AddAttrHelper(const string & name)180 AttrValue* Node::AddAttrHelper(const string& name) {
181 MaybeCopyOnWrite();
182 return &((*props_->node_def.mutable_attr())[name]);
183 }
184
ClearAttr(const string & name)185 void Node::ClearAttr(const string& name) {
186 MaybeCopyOnWrite();
187 (*props_->node_def.mutable_attr()).erase(name);
188 }
189
set_requested_device(const string & device)190 void Node::set_requested_device(const string& device) {
191 MaybeCopyOnWrite();
192 props_->node_def.set_device(device);
193 }
194
input_edge(int idx,const Edge ** e) const195 Status Node::input_edge(int idx, const Edge** e) const {
196 if (idx < 0 || idx >= num_inputs()) {
197 return errors::InvalidArgument("Invalid input_edge index: ", idx, ", Node ",
198 name(), " only has ", num_inputs(),
199 " inputs.");
200 }
201
202 // This does a linear search over the edges. In the common case,
203 // the number of elements is small enough that this search isn't
204 // expensive. Should it become a bottleneck, one can make an
205 // optimization where, if the number of edges is small, we use
206 // linear iteration, and if the number of edges is large, we perform
207 // an indexing step during construction that keeps an array of Edges
208 // indexed by pointer. This would keep the size of each Node small
209 // in the common case but make this function faster when the number
210 // of edges is large.
211 for (const Edge* edge : in_edges()) {
212 if (edge->dst_input() == idx) {
213 *e = edge;
214 return Status::OK();
215 }
216 }
217
218 return errors::NotFound("Could not find input edge ", idx, " for ", name());
219 }
220
221 // Returns a vector of the non-control input edges to a node, indexed by ID.
input_edges(std::vector<const Edge * > * input_edges) const222 Status Node::input_edges(std::vector<const Edge*>* input_edges) const {
223 input_edges->clear();
224 input_edges->resize(num_inputs(), nullptr);
225
226 for (const Edge* edge : in_edges()) {
227 if (edge->IsControlEdge()) continue;
228 if (edge->dst_input() < 0 || edge->dst_input() >= num_inputs()) {
229 return errors::Internal("Invalid edge input number ", edge->dst_input());
230 }
231 if ((*input_edges)[edge->dst_input()] != nullptr) {
232 return errors::Internal("Duplicate edge input number: ",
233 edge->dst_input());
234 }
235 (*input_edges)[edge->dst_input()] = edge;
236 }
237
238 for (int i = 0; i < num_inputs(); ++i) {
239 if ((*input_edges)[i] == nullptr) {
240 return errors::InvalidArgument("Missing edge input number: ", i);
241 }
242 }
243 return Status::OK();
244 }
245
input_node(int idx,Node ** n) const246 Status Node::input_node(int idx, Node** n) const {
247 const Edge* e;
248 TF_RETURN_IF_ERROR(input_edge(idx, &e));
249 if (e == nullptr) {
250 *n = nullptr;
251 } else {
252 *n = e->src();
253 }
254 return Status::OK();
255 }
256
input_node(int idx,const Node ** const_n) const257 Status Node::input_node(int idx, const Node** const_n) const {
258 Node* n;
259 TF_RETURN_IF_ERROR(input_node(idx, &n));
260 *const_n = n;
261 return Status::OK();
262 }
263
264 // Graph
265
Graph(const OpRegistryInterface * ops)266 Graph::Graph(const OpRegistryInterface* ops)
267 : ops_(ops, FunctionDefLibrary()),
268 versions_(new VersionDef),
269 arena_(8 << 10 /* 8kB */) {
270 versions_->set_producer(TF_GRAPH_DEF_VERSION);
271 versions_->set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER);
272
273 // Initialize the name interning table for assigned_device_name.
274 device_names_.push_back("");
275 DCHECK_EQ(0, InternDeviceName(""));
276
277 // Source and sink have no endpoints, just control edges.
278 NodeDef def;
279 def.set_name("_SOURCE");
280 def.set_op("NoOp");
281 Status status;
282 Node* source = AddNode(def, &status);
283 TF_CHECK_OK(status);
284 CHECK_EQ(source->id(), kSourceId);
285
286 def.set_name("_SINK");
287 Node* sink = AddNode(def, &status);
288 TF_CHECK_OK(status);
289 CHECK_EQ(sink->id(), kSinkId);
290
291 AddControlEdge(source, sink);
292 }
293
Graph(const FunctionLibraryDefinition & flib_def)294 Graph::Graph(const FunctionLibraryDefinition& flib_def)
295 : Graph(flib_def.default_registry()) {
296 // Need a new-enough consumer to support the functions we add to the graph.
297 if (flib_def.ToProto().function_size() > 0 &&
298 versions_->min_consumer() < 12) {
299 versions_->set_min_consumer(12);
300 }
301 Status s = ops_.AddLibrary(flib_def);
302 CHECK(s.ok()) << s.error_message();
303 }
304
~Graph()305 Graph::~Graph() {
306 // Manually call the destructors for all the Nodes we constructed using
307 // placement new.
308 for (Node* node : nodes_) {
309 if (node != nullptr) {
310 node->~Node();
311 }
312 }
313 for (Node* node : free_nodes_) {
314 node->~Node();
315 }
316 // Edges have no destructor, and we arena-allocated them, so no need to
317 // destroy them.
318 }
319
versions() const320 const VersionDef& Graph::versions() const { return *versions_; }
set_versions(const VersionDef & versions)321 void Graph::set_versions(const VersionDef& versions) { *versions_ = versions; }
322
AddNode(const NodeDef & node_def,Status * status)323 Node* Graph::AddNode(const NodeDef& node_def, Status* status) {
324 const OpDef* op_def;
325 status->Update(ops_.LookUpOpDef(node_def.op(), &op_def));
326 if (!status->ok()) return nullptr;
327
328 DataTypeVector inputs;
329 DataTypeVector outputs;
330 status->Update(InOutTypesForNode(node_def, *op_def, &inputs, &outputs));
331 if (!status->ok()) {
332 *status = AttachDef(*status, node_def);
333 return nullptr;
334 }
335
336 Node* node = AllocateNode(
337 std::make_shared<NodeProperties>(op_def, node_def, inputs, outputs),
338 nullptr);
339 return node;
340 }
341
CopyNode(Node * node)342 Node* Graph::CopyNode(Node* node) {
343 DCHECK(!node->IsSource());
344 DCHECK(!node->IsSink());
345 Node* copy = AllocateNode(node->props_, node);
346 copy->set_assigned_device_name(node->assigned_device_name());
347
348 // Since the OpDef of a function may be owned by the Graph that owns 'node',
349 // relookup the OpDef in the target graph. If it differs, then clone the
350 // node properties with the updated OpDef.
351 const OpDef* op_def;
352 TF_CHECK_OK(ops_.LookUpOpDef(node->type_string(), &op_def));
353 if (op_def != node->props_->op_def) {
354 copy->MaybeCopyOnWrite();
355 copy->props_->op_def = op_def;
356 }
357
358 return copy;
359 }
360
RemoveNode(Node * node)361 void Graph::RemoveNode(Node* node) {
362 TF_DCHECK_OK(IsValidNode(node)) << node->DebugString();
363 DCHECK(!node->IsSource());
364 DCHECK(!node->IsSink());
365
366 // Remove any edges involving this node.
367 while (!node->in_edges_.empty()) {
368 RemoveEdge(*node->in_edges_.begin());
369 }
370 while (!node->out_edges_.empty()) {
371 RemoveEdge(*node->out_edges_.begin());
372 }
373 ReleaseNode(node);
374 }
375
AddEdge(Node * source,int x,Node * dest,int y)376 const Edge* Graph::AddEdge(Node* source, int x, Node* dest, int y) {
377 TF_DCHECK_OK(IsValidNode(source)) << source->DebugString();
378 TF_DCHECK_OK(IsValidNode(dest)) << dest->DebugString();
379
380 // source/sink must only be linked via control slots, and
381 // control slots must only be linked to control slots.
382 if (source == source_node() || dest == sink_node() || x == kControlSlot ||
383 y == kControlSlot) {
384 DCHECK_EQ(x, kControlSlot) << source->DebugString();
385 DCHECK_EQ(y, kControlSlot) << dest->DebugString();
386 }
387
388 Edge* e = nullptr;
389 if (free_edges_.empty()) {
390 e = new (arena_.Alloc(sizeof(Edge))) Edge; // placement new
391 } else {
392 e = free_edges_.back();
393 free_edges_.pop_back();
394 }
395 e->id_ = edges_.size();
396 e->src_ = source;
397 e->dst_ = dest;
398 e->src_output_ = x;
399 e->dst_input_ = y;
400 CHECK(source->out_edges_.insert(e).second);
401 CHECK(dest->in_edges_.insert(e).second);
402 edges_.push_back(e);
403 ++num_edges_;
404 return e;
405 }
406
RemoveEdge(const Edge * e)407 void Graph::RemoveEdge(const Edge* e) {
408 TF_DCHECK_OK(IsValidNode(e->src_)) << e->src_->DebugString();
409 TF_DCHECK_OK(IsValidNode(e->dst_)) << e->dst_->DebugString();
410 CHECK_EQ(e->src_->out_edges_.erase(e), size_t{1});
411 CHECK_EQ(e->dst_->in_edges_.erase(e), size_t{1});
412 CHECK_EQ(e, edges_[e->id_]);
413 CHECK_GT(num_edges_, 0);
414
415 edges_[e->id_] = nullptr;
416
417 Edge* del = const_cast<Edge*>(e);
418 del->src_ = nullptr;
419 del->dst_ = nullptr;
420 del->id_ = -1;
421 del->src_output_ = kControlSlot - 1;
422 del->dst_input_ = kControlSlot - 1;
423 free_edges_.push_back(del);
424 --num_edges_;
425 }
426
AddControlEdge(Node * source,Node * dest,bool allow_duplicates)427 const Edge* Graph::AddControlEdge(Node* source, Node* dest,
428 bool allow_duplicates) {
429 if (!allow_duplicates) {
430 for (const Edge* edge : dest->in_edges()) {
431 if (edge->IsControlEdge() && edge->src() == source) {
432 // The requested edge already exists.
433 return nullptr;
434 }
435 }
436 }
437 // Modify dest's NodeDef if necessary.
438 if (!source->IsSource() && !dest->IsSink() && !allow_duplicates) {
439 // Check if this input is already in dest's NodeDef.
440 const string new_input = strings::StrCat("^", source->name());
441 bool input_exists = false;
442 for (const string& input : dest->props_->node_def.input()) {
443 if (input == new_input) {
444 input_exists = true;
445 break;
446 }
447 }
448 if (!input_exists) {
449 dest->MaybeCopyOnWrite();
450 dest->props_->node_def.add_input(new_input);
451 }
452 }
453 return AddEdge(source, kControlSlot, dest, kControlSlot);
454 }
455
RemoveControlEdge(const Edge * e)456 void Graph::RemoveControlEdge(const Edge* e) {
457 if (!e->src_->IsSource() && !e->dst_->IsSink()) {
458 e->dst_->MaybeCopyOnWrite();
459 std::string e_src_name = strings::StrCat("^", e->src_->name());
460 auto* inputs = e->dst_->props_->node_def.mutable_input();
461 for (auto it = inputs->begin(); it != inputs->end(); ++it) {
462 if (*it == e_src_name) {
463 inputs->erase(it);
464 break;
465 }
466 }
467 }
468 RemoveEdge(e);
469 }
470
UpdateEdge(Node * new_src,int new_src_index,Node * dst,int dst_index)471 Status Graph::UpdateEdge(Node* new_src, int new_src_index, Node* dst,
472 int dst_index) {
473 TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index));
474 TF_RETURN_IF_ERROR(IsValidInputTensor(dst, dst_index));
475 const Edge* e = FindEdge(dst, dst_index);
476 if (e == nullptr) {
477 return errors::InvalidArgument("Couldn't find edge to ",
478 dst->DebugString());
479 }
480 RemoveEdge(e);
481 AddEdge(new_src, new_src_index, dst, dst_index);
482 dst->MaybeCopyOnWrite();
483 (*dst->props_->node_def.mutable_input())[dst_index] =
484 strings::StrCat(new_src->name(), ":", new_src_index);
485 return Status::OK();
486 }
487
FindEdge(const Node * dst,int index)488 const Edge* Graph::FindEdge(const Node* dst, int index) {
489 for (const Edge* e : edges_) {
490 // edges_ will contain null edges if RemoveEdge() was called.
491 if (e == nullptr) continue;
492 if (e->dst() == dst && e->dst_input() == index) {
493 return e;
494 }
495 }
496 return nullptr;
497 }
498
AddFunctionLibrary(const FunctionDefLibrary & fdef_lib)499 Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) {
500 // Need a new-enough consumer to support the functions we add to the graph.
501 if (fdef_lib.function_size() > 0 && versions_->min_consumer() < 12) {
502 versions_->set_min_consumer(12);
503 }
504 return ops_.AddLibrary(fdef_lib);
505 }
506
507 namespace {
508
AddInput(NodeDef * dst,StringPiece src_name,int src_slot)509 void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) {
510 if (src_slot == Graph::kControlSlot) {
511 dst->add_input(strings::StrCat("^", src_name));
512 } else if (src_slot == 0) {
513 dst->add_input(src_name.data(), src_name.size());
514 } else {
515 dst->add_input(strings::StrCat(src_name, ":", src_slot));
516 }
517 }
518
519 } // namespace
520
ToGraphDef(GraphDef * graph_def) const521 void Graph::ToGraphDef(GraphDef* graph_def) const {
522 ToGraphDefSubRange(graph_def, 0);
523 }
524
ToGraphDefDebug() const525 GraphDef Graph::ToGraphDefDebug() const {
526 GraphDef ret;
527 ToGraphDef(&ret);
528 return ret;
529 }
530
ToGraphDefSubRange(GraphDef * graph_def,int from_node_id) const531 void Graph::ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const {
532 graph_def->Clear();
533 *graph_def->mutable_versions() = versions();
534 *graph_def->mutable_library() = ops_.ToProto();
535
536 graph_def->mutable_node()->Reserve(std::max(1, num_nodes() - from_node_id));
537
538 std::vector<const Edge*>
539 inputs; // Construct this outside the loop for speed.
540 for (auto id = from_node_id; id < num_node_ids(); ++id) {
541 const Node* node = FindNodeId(id);
542 if (node == nullptr || !node->IsOp()) continue;
543 NodeDef* node_def = graph_def->add_node();
544 *node_def = node->def();
545
546 // Use the node's assigned device, if any, instead of the device requested
547 // in the NodeDef.
548 if (!node->assigned_device_name().empty()) {
549 node_def->set_device(node->assigned_device_name());
550 }
551
552 // Get the inputs for this Node. We make sure control inputs are
553 // after data inputs, as required by GraphDef.
554 inputs.clear();
555 inputs.resize(node->num_inputs(), nullptr);
556 for (const Edge* edge : node->in_edges()) {
557 if (edge->IsControlEdge()) {
558 inputs.push_back(edge);
559 } else {
560 CHECK(inputs[edge->dst_input()] == nullptr)
561 << "Edge " << edge->src()->DebugString() << ":"
562 << edge->dst()->DebugString() << " with dst_input "
563 << edge->dst_input() << " and had pre-existing input edge "
564 << inputs[edge->dst_input()]->src()->DebugString() << ":"
565 << inputs[edge->dst_input()]->dst()->DebugString();
566
567 inputs[edge->dst_input()] = edge;
568 }
569 }
570 node_def->clear_input();
571 node_def->mutable_input()->Reserve(inputs.size());
572
573 for (size_t i = 0; i < inputs.size(); ++i) {
574 const Edge* edge = inputs[i];
575 if (edge == nullptr) {
576 if (i < node->requested_inputs().size()) {
577 node_def->add_input(node->requested_inputs()[i]);
578 } else {
579 node_def->add_input("");
580 }
581 } else {
582 const Node* src = edge->src();
583 if (!src->IsOp()) continue;
584 AddInput(node_def, src->name(), edge->src_output());
585 }
586 }
587 }
588 }
589
NewName(StringPiece prefix)590 string Graph::NewName(StringPiece prefix) {
591 return strings::StrCat(prefix, "/_", name_counter_++);
592 }
593
IsValidNode(const Node * node) const594 Status Graph::IsValidNode(const Node* node) const {
595 if (node == nullptr) {
596 return errors::InvalidArgument("Node is null");
597 }
598 const int id = node->id();
599 if (id < 0) {
600 return errors::InvalidArgument("node id ", id, " is less than zero");
601 }
602 if (static_cast<size_t>(id) >= nodes_.size()) {
603 return errors::InvalidArgument(
604 "node id ", id, " is >= than number of nodes in graph ", nodes_.size());
605 }
606 if (nodes_[id] != node) {
607 return errors::InvalidArgument("Node with id ", id,
608 " is different from the passed in node. "
609 "Does it belong to a different graph?");
610 }
611 return Status::OK();
612 }
613
IsValidOutputTensor(const Node * node,int idx) const614 Status Graph::IsValidOutputTensor(const Node* node, int idx) const {
615 TF_RETURN_IF_ERROR(IsValidNode(node));
616 if (idx >= node->num_outputs()) {
617 return errors::OutOfRange("Node '", node->name(), "' (type: '",
618 node->op_def().name(),
619 "', num of outputs: ", node->num_outputs(),
620 ") does not have ", "output ", idx);
621 }
622 return Status::OK();
623 }
624
IsValidInputTensor(const Node * node,int idx) const625 Status Graph::IsValidInputTensor(const Node* node, int idx) const {
626 TF_RETURN_IF_ERROR(IsValidNode(node));
627 if (idx >= node->num_inputs()) {
628 return errors::OutOfRange("Node '", node->name(), "' (type: '",
629 node->op_def().name(),
630 "', num of inputs: ", node->num_inputs(),
631 ") does not have ", "input ", idx);
632 }
633 return Status::OK();
634 }
635
// Obtains a Node (recycled from the free list when possible, otherwise
// arena-allocated), initializes it with `props`, and appends it to the node
// table. The new node's id is its position in the table; its cost id is
// taken from `cost_node` when one is given and equals the id otherwise.
Node* Graph::AllocateNode(std::shared_ptr<NodeProperties> props,
                          const Node* cost_node) {
  Node* node;
  if (!free_nodes_.empty()) {
    node = free_nodes_.back();
    free_nodes_.pop_back();
  } else {
    node = new (arena_.Alloc(sizeof(Node))) Node;  // placement new
  }
  node->graph_ = this;

  const int id = nodes_.size();
  const int cost_id = (cost_node != nullptr) ? cost_node->cost_id() : id;
  node->Initialize(id, cost_id, std::move(props));

  nodes_.push_back(node);
  ++num_nodes_;
  return node;
}
653
ReleaseNode(Node * node)654 void Graph::ReleaseNode(Node* node) {
655 TF_DCHECK_OK(IsValidNode(node)) << node->DebugString();
656 nodes_[node->id()] = nullptr;
657 free_nodes_.push_back(node);
658 --num_nodes_;
659 node->Clear();
660 }
661
662 // Ensures that 'device_name' is present in the device name table, and returns
663 // the index of that device name. The index is stable, and can be used in
664 // calls to Node::set_assigned_device_name_index().
InternDeviceName(const string & device_name)665 int Graph::InternDeviceName(const string& device_name) {
666 // Special case, very common. Also, this allows us to use a single map
667 // lookup below, instead of two. The 'if (index_cell > 0)' test below
668 // relies on this check.
669 if (device_name.empty()) {
670 return 0;
671 }
672
673 int& index_cell = device_names_map_[device_name];
674 if (index_cell > 0) {
675 return index_cell;
676 }
677
678 const int index = device_names_map_.size();
679 index_cell = index;
680 device_names_.push_back(device_name);
681 return index;
682 }
683
AddWhileContext(StringPiece frame_name,std::vector<Node * > enter_nodes,std::vector<Node * > exit_nodes,OutputTensor cond_output,std::vector<OutputTensor> body_inputs,std::vector<OutputTensor> body_outputs,WhileContext ** result)684 Status Graph::AddWhileContext(StringPiece frame_name,
685 std::vector<Node*> enter_nodes,
686 std::vector<Node*> exit_nodes,
687 OutputTensor cond_output,
688 std::vector<OutputTensor> body_inputs,
689 std::vector<OutputTensor> body_outputs,
690 WhileContext** result) {
691 auto pair = while_ctxs_.insert(std::pair<string, WhileContext>(
692 frame_name.ToString(),
693 WhileContext(frame_name, std::move(enter_nodes), std::move(exit_nodes),
694 cond_output, std::move(body_inputs),
695 std::move(body_outputs))));
696 if (!pair.second) {
697 *result = nullptr;
698 return errors::InvalidArgument("WhileContext with frame name '", frame_name,
699 "' already exists");
700 }
701 *result = &pair.first->second;
702 return Status::OK();
703 }
704
DebugString() const705 string Edge::DebugString() const {
706 return strings::Printf("[id=%d %s:%d -> %s:%d]", id_, src_->name().c_str(),
707 src_output_, dst_->name().c_str(), dst_input_);
708 }
709
710 } // namespace tensorflow
711