• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/graph/graph.h"
17 
18 #include <vector>
19 #include "tensorflow/core/framework/graph.pb.h"
20 #include "tensorflow/core/framework/node_def.pb.h"
21 #include "tensorflow/core/framework/node_def_util.h"
22 #include "tensorflow/core/framework/op_kernel.h"
23 #include "tensorflow/core/framework/versions.pb.h"
24 #include "tensorflow/core/graph/while_context.h"
25 #include "tensorflow/core/lib/core/errors.h"
26 #include "tensorflow/core/lib/gtl/map_util.h"
27 #include "tensorflow/core/lib/strings/strcat.h"
28 #include "tensorflow/core/lib/strings/stringprintf.h"
29 #include "tensorflow/core/platform/logging.h"
30 #include "tensorflow/core/public/version.h"
31 
32 namespace tensorflow {
33 
34 const int Graph::kControlSlot = -1;
35 
36 class NodeProperties {
37  public:
NodeProperties(const OpDef * op_def,const NodeDef & node_def,const DataTypeSlice inputs,const DataTypeSlice outputs)38   NodeProperties(const OpDef* op_def, const NodeDef& node_def,
39                  const DataTypeSlice inputs, const DataTypeSlice outputs)
40       : op_def(op_def),
41         node_def(node_def),
42         input_types(inputs.begin(), inputs.end()),
43         output_types(outputs.begin(), outputs.end()) {}
44 
45   const OpDef* op_def;  // not owned
46   NodeDef node_def;
47   const DataTypeVector input_types;
48   const DataTypeVector output_types;
49 };
50 
51 // Node
52 
53 #define REF_CLASS(key, value) \
54   {key, value}, { "Ref" key, value }
55 
56 const std::unordered_map<string, Node::NodeClass>& Node::kNodeClassTable =
57     *new std::unordered_map<string, Node::NodeClass>({
58         // Keep in same order as NodeClass values
59         REF_CLASS("Switch", NC_SWITCH),
60         REF_CLASS("Merge", NC_MERGE),
61         REF_CLASS("Enter", NC_ENTER),
62         REF_CLASS("Exit", NC_EXIT),
63         REF_CLASS("NextIteration", NC_NEXT_ITERATION),
64         {"LoopCond", NC_LOOP_COND},
65         {"ControlTrigger", NC_CONTROL_TRIGGER},
66         {"_Send", NC_SEND},
67         {"_HostSend", NC_HOST_SEND},
68         {"_Recv", NC_RECV},
69         {"_HostRecv", NC_HOST_RECV},
70         {"Const", NC_CONSTANT},
71         {"HostConst", NC_CONSTANT},
72         {"Variable", NC_VARIABLE},
73         {"VariableV2", NC_VARIABLE},
74         REF_CLASS("Identity", NC_IDENTITY),
75         {"GetSessionHandle", NC_GET_SESSION_HANDLE},
76         {"GetSessionHandleV2", NC_GET_SESSION_HANDLE},
77         {"GetSessionTensor", NC_GET_SESSION_TENSOR},
78         {"DeleteSessionTensor", NC_DELETE_SESSION_TENSOR},
79         {"Size", NC_METADATA},
80         {"Shape", NC_METADATA},
81         {"Rank", NC_METADATA},
82     });
83 
84 #undef REF_CLASS
85 
GetNodeClassForOp(const string & ts)86 Node::NodeClass Node::GetNodeClassForOp(const string& ts) {
87   auto it = kNodeClassTable.find(ts);
88   if (it != kNodeClassTable.end()) {
89     return it->second;
90   } else {
91     return NC_OTHER;
92   }
93 }
94 
DebugString() const95 string Node::DebugString() const {
96   string ret = strings::StrCat("{name:'", name(), "' id:", id_);
97   if (IsSource()) {
98     strings::StrAppend(&ret, " source}");
99   } else if (IsSink()) {
100     strings::StrAppend(&ret, " sink}");
101   } else {
102     strings::StrAppend(&ret, " op device:");
103     strings::StrAppend(&ret, "{", assigned_device_name(), "}");
104     strings::StrAppend(&ret, " def:{", SummarizeNode(*this), "}}");
105   }
106   return ret;
107 }
108 
Node()109 Node::Node()
110     : id_(-1),
111       cost_id_(-1),
112       class_(NC_UNINITIALIZED),
113       props_(nullptr),
114       assigned_device_name_index_(0),
115       while_ctx_(nullptr) {}
116 
Initialize(int id,int cost_id,std::shared_ptr<NodeProperties> props)117 void Node::Initialize(int id, int cost_id,
118                       std::shared_ptr<NodeProperties> props) {
119   DCHECK_EQ(id_, -1);
120   DCHECK(in_edges_.empty());
121   DCHECK(out_edges_.empty());
122   id_ = id;
123   cost_id_ = cost_id;
124 
125   props_ = std::move(props);
126   // Initialize the class_ based on the type string
127   class_ = GetNodeClassForOp(props_->node_def.op());
128 }
129 
Clear()130 void Node::Clear() {
131   in_edges_.clear();
132   out_edges_.clear();
133   id_ = -1;
134   cost_id_ = -1;
135   class_ = NC_UNINITIALIZED;
136   props_.reset();
137   assigned_device_name_index_ = 0;
138 }
139 
name() const140 const string& Node::name() const { return props_->node_def.name(); }
type_string() const141 const string& Node::type_string() const { return props_->node_def.op(); }
def() const142 const NodeDef& Node::def() const { return props_->node_def; }
op_def() const143 const OpDef& Node::op_def() const { return *props_->op_def; }
144 
num_inputs() const145 int32 Node::num_inputs() const { return props_->input_types.size(); }
input_type(int32 i) const146 DataType Node::input_type(int32 i) const { return props_->input_types[i]; }
input_types() const147 const DataTypeVector& Node::input_types() const { return props_->input_types; }
148 
num_outputs() const149 int32 Node::num_outputs() const { return props_->output_types.size(); }
output_type(int32 o) const150 DataType Node::output_type(int32 o) const { return props_->output_types[o]; }
output_types() const151 const DataTypeVector& Node::output_types() const {
152   return props_->output_types;
153 }
154 
attrs() const155 AttrSlice Node::attrs() const { return AttrSlice(def()); }
156 
requested_inputs() const157 const protobuf::RepeatedPtrField<string>& Node::requested_inputs() const {
158   return def().input();
159 }
160 
requested_device() const161 const string& Node::requested_device() const { return def().device(); }
162 
out_nodes() const163 gtl::iterator_range<NeighborIter> Node::out_nodes() const {
164   return gtl::make_range(NeighborIter(out_edges_.begin(), false),
165                          NeighborIter(out_edges_.end(), false));
166 }
167 
in_nodes() const168 gtl::iterator_range<NeighborIter> Node::in_nodes() const {
169   return gtl::make_range(NeighborIter(in_edges_.begin(), true),
170                          NeighborIter(in_edges_.end(), true));
171 }
172 
MaybeCopyOnWrite()173 void Node::MaybeCopyOnWrite() {
174   // NodeProperties may be shared between Nodes. Make a copy if so.
175   if (!props_.unique()) {
176     props_ = std::make_shared<NodeProperties>(*props_);
177   }
178 }
179 
AddAttrHelper(const string & name)180 AttrValue* Node::AddAttrHelper(const string& name) {
181   MaybeCopyOnWrite();
182   return &((*props_->node_def.mutable_attr())[name]);
183 }
184 
ClearAttr(const string & name)185 void Node::ClearAttr(const string& name) {
186   MaybeCopyOnWrite();
187   (*props_->node_def.mutable_attr()).erase(name);
188 }
189 
set_requested_device(const string & device)190 void Node::set_requested_device(const string& device) {
191   MaybeCopyOnWrite();
192   props_->node_def.set_device(device);
193 }
194 
input_edge(int idx,const Edge ** e) const195 Status Node::input_edge(int idx, const Edge** e) const {
196   if (idx < 0 || idx >= num_inputs()) {
197     return errors::InvalidArgument("Invalid input_edge index: ", idx, ", Node ",
198                                    name(), " only has ", num_inputs(),
199                                    " inputs.");
200   }
201 
202   // This does a linear search over the edges.  In the common case,
203   // the number of elements is small enough that this search isn't
204   // expensive.  Should it become a bottleneck, one can make an
205   // optimization where, if the number of edges is small, we use
206   // linear iteration, and if the number of edges is large, we perform
207   // an indexing step during construction that keeps an array of Edges
208   // indexed by pointer.  This would keep the size of each Node small
209   // in the common case but make this function faster when the number
210   // of edges is large.
211   for (const Edge* edge : in_edges()) {
212     if (edge->dst_input() == idx) {
213       *e = edge;
214       return Status::OK();
215     }
216   }
217 
218   return errors::NotFound("Could not find input edge ", idx, " for ", name());
219 }
220 
221 // Returns a vector of the non-control input edges to a node, indexed by ID.
input_edges(std::vector<const Edge * > * input_edges) const222 Status Node::input_edges(std::vector<const Edge*>* input_edges) const {
223   input_edges->clear();
224   input_edges->resize(num_inputs(), nullptr);
225 
226   for (const Edge* edge : in_edges()) {
227     if (edge->IsControlEdge()) continue;
228     if (edge->dst_input() < 0 || edge->dst_input() >= num_inputs()) {
229       return errors::Internal("Invalid edge input number ", edge->dst_input());
230     }
231     if ((*input_edges)[edge->dst_input()] != nullptr) {
232       return errors::Internal("Duplicate edge input number: ",
233                               edge->dst_input());
234     }
235     (*input_edges)[edge->dst_input()] = edge;
236   }
237 
238   for (int i = 0; i < num_inputs(); ++i) {
239     if ((*input_edges)[i] == nullptr) {
240       return errors::InvalidArgument("Missing edge input number: ", i);
241     }
242   }
243   return Status::OK();
244 }
245 
input_node(int idx,Node ** n) const246 Status Node::input_node(int idx, Node** n) const {
247   const Edge* e;
248   TF_RETURN_IF_ERROR(input_edge(idx, &e));
249   if (e == nullptr) {
250     *n = nullptr;
251   } else {
252     *n = e->src();
253   }
254   return Status::OK();
255 }
256 
input_node(int idx,const Node ** const_n) const257 Status Node::input_node(int idx, const Node** const_n) const {
258   Node* n;
259   TF_RETURN_IF_ERROR(input_node(idx, &n));
260   *const_n = n;
261   return Status::OK();
262 }
263 
264 // Graph
265 
Graph(const OpRegistryInterface * ops)266 Graph::Graph(const OpRegistryInterface* ops)
267     : ops_(ops, FunctionDefLibrary()),
268       versions_(new VersionDef),
269       arena_(8 << 10 /* 8kB */) {
270   versions_->set_producer(TF_GRAPH_DEF_VERSION);
271   versions_->set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER);
272 
273   // Initialize the name interning table for assigned_device_name.
274   device_names_.push_back("");
275   DCHECK_EQ(0, InternDeviceName(""));
276 
277   // Source and sink have no endpoints, just control edges.
278   NodeDef def;
279   def.set_name("_SOURCE");
280   def.set_op("NoOp");
281   Status status;
282   Node* source = AddNode(def, &status);
283   TF_CHECK_OK(status);
284   CHECK_EQ(source->id(), kSourceId);
285 
286   def.set_name("_SINK");
287   Node* sink = AddNode(def, &status);
288   TF_CHECK_OK(status);
289   CHECK_EQ(sink->id(), kSinkId);
290 
291   AddControlEdge(source, sink);
292 }
293 
Graph(const FunctionLibraryDefinition & flib_def)294 Graph::Graph(const FunctionLibraryDefinition& flib_def)
295     : Graph(flib_def.default_registry()) {
296   // Need a new-enough consumer to support the functions we add to the graph.
297   if (flib_def.ToProto().function_size() > 0 &&
298       versions_->min_consumer() < 12) {
299     versions_->set_min_consumer(12);
300   }
301   Status s = ops_.AddLibrary(flib_def);
302   CHECK(s.ok()) << s.error_message();
303 }
304 
~Graph()305 Graph::~Graph() {
306   // Manually call the destructors for all the Nodes we constructed using
307   // placement new.
308   for (Node* node : nodes_) {
309     if (node != nullptr) {
310       node->~Node();
311     }
312   }
313   for (Node* node : free_nodes_) {
314     node->~Node();
315   }
316   // Edges have no destructor, and we arena-allocated them, so no need to
317   // destroy them.
318 }
319 
versions() const320 const VersionDef& Graph::versions() const { return *versions_; }
set_versions(const VersionDef & versions)321 void Graph::set_versions(const VersionDef& versions) { *versions_ = versions; }
322 
AddNode(const NodeDef & node_def,Status * status)323 Node* Graph::AddNode(const NodeDef& node_def, Status* status) {
324   const OpDef* op_def;
325   status->Update(ops_.LookUpOpDef(node_def.op(), &op_def));
326   if (!status->ok()) return nullptr;
327 
328   DataTypeVector inputs;
329   DataTypeVector outputs;
330   status->Update(InOutTypesForNode(node_def, *op_def, &inputs, &outputs));
331   if (!status->ok()) {
332     *status = AttachDef(*status, node_def);
333     return nullptr;
334   }
335 
336   Node* node = AllocateNode(
337       std::make_shared<NodeProperties>(op_def, node_def, inputs, outputs),
338       nullptr);
339   return node;
340 }
341 
CopyNode(Node * node)342 Node* Graph::CopyNode(Node* node) {
343   DCHECK(!node->IsSource());
344   DCHECK(!node->IsSink());
345   Node* copy = AllocateNode(node->props_, node);
346   copy->set_assigned_device_name(node->assigned_device_name());
347 
348   // Since the OpDef of a function may be owned by the Graph that owns 'node',
349   // relookup the OpDef in the target graph. If it differs, then clone the
350   // node properties with the updated OpDef.
351   const OpDef* op_def;
352   TF_CHECK_OK(ops_.LookUpOpDef(node->type_string(), &op_def));
353   if (op_def != node->props_->op_def) {
354     copy->MaybeCopyOnWrite();
355     copy->props_->op_def = op_def;
356   }
357 
358   return copy;
359 }
360 
RemoveNode(Node * node)361 void Graph::RemoveNode(Node* node) {
362   TF_DCHECK_OK(IsValidNode(node)) << node->DebugString();
363   DCHECK(!node->IsSource());
364   DCHECK(!node->IsSink());
365 
366   // Remove any edges involving this node.
367   while (!node->in_edges_.empty()) {
368     RemoveEdge(*node->in_edges_.begin());
369   }
370   while (!node->out_edges_.empty()) {
371     RemoveEdge(*node->out_edges_.begin());
372   }
373   ReleaseNode(node);
374 }
375 
AddEdge(Node * source,int x,Node * dest,int y)376 const Edge* Graph::AddEdge(Node* source, int x, Node* dest, int y) {
377   TF_DCHECK_OK(IsValidNode(source)) << source->DebugString();
378   TF_DCHECK_OK(IsValidNode(dest)) << dest->DebugString();
379 
380   // source/sink must only be linked via control slots, and
381   // control slots must only be linked to control slots.
382   if (source == source_node() || dest == sink_node() || x == kControlSlot ||
383       y == kControlSlot) {
384     DCHECK_EQ(x, kControlSlot) << source->DebugString();
385     DCHECK_EQ(y, kControlSlot) << dest->DebugString();
386   }
387 
388   Edge* e = nullptr;
389   if (free_edges_.empty()) {
390     e = new (arena_.Alloc(sizeof(Edge))) Edge;  // placement new
391   } else {
392     e = free_edges_.back();
393     free_edges_.pop_back();
394   }
395   e->id_ = edges_.size();
396   e->src_ = source;
397   e->dst_ = dest;
398   e->src_output_ = x;
399   e->dst_input_ = y;
400   CHECK(source->out_edges_.insert(e).second);
401   CHECK(dest->in_edges_.insert(e).second);
402   edges_.push_back(e);
403   ++num_edges_;
404   return e;
405 }
406 
RemoveEdge(const Edge * e)407 void Graph::RemoveEdge(const Edge* e) {
408   TF_DCHECK_OK(IsValidNode(e->src_)) << e->src_->DebugString();
409   TF_DCHECK_OK(IsValidNode(e->dst_)) << e->dst_->DebugString();
410   CHECK_EQ(e->src_->out_edges_.erase(e), size_t{1});
411   CHECK_EQ(e->dst_->in_edges_.erase(e), size_t{1});
412   CHECK_EQ(e, edges_[e->id_]);
413   CHECK_GT(num_edges_, 0);
414 
415   edges_[e->id_] = nullptr;
416 
417   Edge* del = const_cast<Edge*>(e);
418   del->src_ = nullptr;
419   del->dst_ = nullptr;
420   del->id_ = -1;
421   del->src_output_ = kControlSlot - 1;
422   del->dst_input_ = kControlSlot - 1;
423   free_edges_.push_back(del);
424   --num_edges_;
425 }
426 
AddControlEdge(Node * source,Node * dest,bool allow_duplicates)427 const Edge* Graph::AddControlEdge(Node* source, Node* dest,
428                                   bool allow_duplicates) {
429   if (!allow_duplicates) {
430     for (const Edge* edge : dest->in_edges()) {
431       if (edge->IsControlEdge() && edge->src() == source) {
432         // The requested edge already exists.
433         return nullptr;
434       }
435     }
436   }
437   // Modify dest's NodeDef if necessary.
438   if (!source->IsSource() && !dest->IsSink() && !allow_duplicates) {
439     // Check if this input is already in dest's NodeDef.
440     const string new_input = strings::StrCat("^", source->name());
441     bool input_exists = false;
442     for (const string& input : dest->props_->node_def.input()) {
443       if (input == new_input) {
444         input_exists = true;
445         break;
446       }
447     }
448     if (!input_exists) {
449       dest->MaybeCopyOnWrite();
450       dest->props_->node_def.add_input(new_input);
451     }
452   }
453   return AddEdge(source, kControlSlot, dest, kControlSlot);
454 }
455 
RemoveControlEdge(const Edge * e)456 void Graph::RemoveControlEdge(const Edge* e) {
457   if (!e->src_->IsSource() && !e->dst_->IsSink()) {
458     e->dst_->MaybeCopyOnWrite();
459     std::string e_src_name = strings::StrCat("^", e->src_->name());
460     auto* inputs = e->dst_->props_->node_def.mutable_input();
461     for (auto it = inputs->begin(); it != inputs->end(); ++it) {
462       if (*it == e_src_name) {
463         inputs->erase(it);
464         break;
465       }
466     }
467   }
468   RemoveEdge(e);
469 }
470 
UpdateEdge(Node * new_src,int new_src_index,Node * dst,int dst_index)471 Status Graph::UpdateEdge(Node* new_src, int new_src_index, Node* dst,
472                          int dst_index) {
473   TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index));
474   TF_RETURN_IF_ERROR(IsValidInputTensor(dst, dst_index));
475   const Edge* e = FindEdge(dst, dst_index);
476   if (e == nullptr) {
477     return errors::InvalidArgument("Couldn't find edge to ",
478                                    dst->DebugString());
479   }
480   RemoveEdge(e);
481   AddEdge(new_src, new_src_index, dst, dst_index);
482   dst->MaybeCopyOnWrite();
483   (*dst->props_->node_def.mutable_input())[dst_index] =
484       strings::StrCat(new_src->name(), ":", new_src_index);
485   return Status::OK();
486 }
487 
FindEdge(const Node * dst,int index)488 const Edge* Graph::FindEdge(const Node* dst, int index) {
489   for (const Edge* e : edges_) {
490     // edges_ will contain null edges if RemoveEdge() was called.
491     if (e == nullptr) continue;
492     if (e->dst() == dst && e->dst_input() == index) {
493       return e;
494     }
495   }
496   return nullptr;
497 }
498 
AddFunctionLibrary(const FunctionDefLibrary & fdef_lib)499 Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) {
500   // Need a new-enough consumer to support the functions we add to the graph.
501   if (fdef_lib.function_size() > 0 && versions_->min_consumer() < 12) {
502     versions_->set_min_consumer(12);
503   }
504   return ops_.AddLibrary(fdef_lib);
505 }
506 
507 namespace {
508 
AddInput(NodeDef * dst,StringPiece src_name,int src_slot)509 void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) {
510   if (src_slot == Graph::kControlSlot) {
511     dst->add_input(strings::StrCat("^", src_name));
512   } else if (src_slot == 0) {
513     dst->add_input(src_name.data(), src_name.size());
514   } else {
515     dst->add_input(strings::StrCat(src_name, ":", src_slot));
516   }
517 }
518 
519 }  // namespace
520 
ToGraphDef(GraphDef * graph_def) const521 void Graph::ToGraphDef(GraphDef* graph_def) const {
522   ToGraphDefSubRange(graph_def, 0);
523 }
524 
ToGraphDefDebug() const525 GraphDef Graph::ToGraphDefDebug() const {
526   GraphDef ret;
527   ToGraphDef(&ret);
528   return ret;
529 }
530 
ToGraphDefSubRange(GraphDef * graph_def,int from_node_id) const531 void Graph::ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const {
532   graph_def->Clear();
533   *graph_def->mutable_versions() = versions();
534   *graph_def->mutable_library() = ops_.ToProto();
535 
536   graph_def->mutable_node()->Reserve(std::max(1, num_nodes() - from_node_id));
537 
538   std::vector<const Edge*>
539       inputs;  // Construct this outside the loop for speed.
540   for (auto id = from_node_id; id < num_node_ids(); ++id) {
541     const Node* node = FindNodeId(id);
542     if (node == nullptr || !node->IsOp()) continue;
543     NodeDef* node_def = graph_def->add_node();
544     *node_def = node->def();
545 
546     // Use the node's assigned device, if any, instead of the device requested
547     // in the NodeDef.
548     if (!node->assigned_device_name().empty()) {
549       node_def->set_device(node->assigned_device_name());
550     }
551 
552     // Get the inputs for this Node.  We make sure control inputs are
553     // after data inputs, as required by GraphDef.
554     inputs.clear();
555     inputs.resize(node->num_inputs(), nullptr);
556     for (const Edge* edge : node->in_edges()) {
557       if (edge->IsControlEdge()) {
558         inputs.push_back(edge);
559       } else {
560         CHECK(inputs[edge->dst_input()] == nullptr)
561             << "Edge " << edge->src()->DebugString() << ":"
562             << edge->dst()->DebugString() << " with dst_input "
563             << edge->dst_input() << " and had pre-existing input edge "
564             << inputs[edge->dst_input()]->src()->DebugString() << ":"
565             << inputs[edge->dst_input()]->dst()->DebugString();
566 
567         inputs[edge->dst_input()] = edge;
568       }
569     }
570     node_def->clear_input();
571     node_def->mutable_input()->Reserve(inputs.size());
572 
573     for (size_t i = 0; i < inputs.size(); ++i) {
574       const Edge* edge = inputs[i];
575       if (edge == nullptr) {
576         if (i < node->requested_inputs().size()) {
577           node_def->add_input(node->requested_inputs()[i]);
578         } else {
579           node_def->add_input("");
580         }
581       } else {
582         const Node* src = edge->src();
583         if (!src->IsOp()) continue;
584         AddInput(node_def, src->name(), edge->src_output());
585       }
586     }
587   }
588 }
589 
NewName(StringPiece prefix)590 string Graph::NewName(StringPiece prefix) {
591   return strings::StrCat(prefix, "/_", name_counter_++);
592 }
593 
IsValidNode(const Node * node) const594 Status Graph::IsValidNode(const Node* node) const {
595   if (node == nullptr) {
596     return errors::InvalidArgument("Node is null");
597   }
598   const int id = node->id();
599   if (id < 0) {
600     return errors::InvalidArgument("node id ", id, " is less than zero");
601   }
602   if (static_cast<size_t>(id) >= nodes_.size()) {
603     return errors::InvalidArgument(
604         "node id ", id, " is >= than number of nodes in graph ", nodes_.size());
605   }
606   if (nodes_[id] != node) {
607     return errors::InvalidArgument("Node with id ", id,
608                                    " is different from the passed in node. "
609                                    "Does it belong to a different graph?");
610   }
611   return Status::OK();
612 }
613 
IsValidOutputTensor(const Node * node,int idx) const614 Status Graph::IsValidOutputTensor(const Node* node, int idx) const {
615   TF_RETURN_IF_ERROR(IsValidNode(node));
616   if (idx >= node->num_outputs()) {
617     return errors::OutOfRange("Node '", node->name(), "' (type: '",
618                               node->op_def().name(),
619                               "', num of outputs: ", node->num_outputs(),
620                               ") does not have ", "output ", idx);
621   }
622   return Status::OK();
623 }
624 
IsValidInputTensor(const Node * node,int idx) const625 Status Graph::IsValidInputTensor(const Node* node, int idx) const {
626   TF_RETURN_IF_ERROR(IsValidNode(node));
627   if (idx >= node->num_inputs()) {
628     return errors::OutOfRange("Node '", node->name(), "' (type: '",
629                               node->op_def().name(),
630                               "', num of inputs: ", node->num_inputs(),
631                               ") does not have ", "input ", idx);
632   }
633   return Status::OK();
634 }
635 
// Obtains a Node (recycled from free_nodes_ if possible, otherwise
// placement-new'ed in arena_), gives it the next id, and initializes it with
// `props`. Its cost id is taken from `cost_node` when one is given (CopyNode()
// passes the original node); otherwise it equals the new id.
Node* Graph::AllocateNode(std::shared_ptr<NodeProperties> props,
                          const Node* cost_node) {
  Node* node;
  if (!free_nodes_.empty()) {
    node = free_nodes_.back();
    free_nodes_.pop_back();
  } else {
    node = new (arena_.Alloc(sizeof(Node))) Node;  // placement new
  }
  node->graph_ = this;
  const int id = nodes_.size();  // ids are indices into nodes_
  const int cost_id = (cost_node != nullptr) ? cost_node->cost_id() : id;
  node->Initialize(id, cost_id, std::move(props));
  nodes_.push_back(node);
  ++num_nodes_;
  return node;
}
653 
ReleaseNode(Node * node)654 void Graph::ReleaseNode(Node* node) {
655   TF_DCHECK_OK(IsValidNode(node)) << node->DebugString();
656   nodes_[node->id()] = nullptr;
657   free_nodes_.push_back(node);
658   --num_nodes_;
659   node->Clear();
660 }
661 
662 // Ensures that 'device_name' is present in the device name table, and returns
663 // the index of that device name. The index is stable, and can be used in
664 // calls to Node::set_assigned_device_name_index().
InternDeviceName(const string & device_name)665 int Graph::InternDeviceName(const string& device_name) {
666   // Special case, very common.  Also, this allows us to use a single map
667   // lookup below, instead of two.  The 'if (index_cell > 0)' test below
668   // relies on this check.
669   if (device_name.empty()) {
670     return 0;
671   }
672 
673   int& index_cell = device_names_map_[device_name];
674   if (index_cell > 0) {
675     return index_cell;
676   }
677 
678   const int index = device_names_map_.size();
679   index_cell = index;
680   device_names_.push_back(device_name);
681   return index;
682 }
683 
AddWhileContext(StringPiece frame_name,std::vector<Node * > enter_nodes,std::vector<Node * > exit_nodes,OutputTensor cond_output,std::vector<OutputTensor> body_inputs,std::vector<OutputTensor> body_outputs,WhileContext ** result)684 Status Graph::AddWhileContext(StringPiece frame_name,
685                               std::vector<Node*> enter_nodes,
686                               std::vector<Node*> exit_nodes,
687                               OutputTensor cond_output,
688                               std::vector<OutputTensor> body_inputs,
689                               std::vector<OutputTensor> body_outputs,
690                               WhileContext** result) {
691   auto pair = while_ctxs_.insert(std::pair<string, WhileContext>(
692       frame_name.ToString(),
693       WhileContext(frame_name, std::move(enter_nodes), std::move(exit_nodes),
694                    cond_output, std::move(body_inputs),
695                    std::move(body_outputs))));
696   if (!pair.second) {
697     *result = nullptr;
698     return errors::InvalidArgument("WhileContext with frame name '", frame_name,
699                                    "' already exists");
700   }
701   *result = &pair.first->second;
702   return Status::OK();
703 }
704 
DebugString() const705 string Edge::DebugString() const {
706   return strings::Printf("[id=%d %s:%d -> %s:%d]", id_, src_->name().c_str(),
707                          src_output_, dst_->name().c_str(), dst_input_);
708 }
709 
710 }  // namespace tensorflow
711