• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// A Graph describes a set of computations that are to be
// performed, as well as the dependencies between those
// computations. The basic model is a DAG (directed acyclic graph) with
// * internal nodes representing computational operations to be performed;
// * edges representing dependencies, indicating the target may only be
//   executed once the source has completed; and
// * predefined "source" (start) and "sink" (finish) nodes -- the source
//   should be the only node that doesn't depend on anything, and the sink
//   should be the only node that nothing depends on.
//
// Note: Node ids are intended to be relatively dense in the
// 0..max_id range, but there may be gaps since ids won't be reused.
//
// Note: Some dependencies between operations are due to one operation
// consuming the output of another. In fact operations can produce
// multiple outputs and consume multiple inputs, and some
// optimizations will care about which specific outputs are connected
// to which specific inputs.  We therefore represent data dependency
// between output O of node A and input I of node B using
// "input index" and "output index" labels per edge.

37 #ifndef TENSORFLOW_GRAPH_GRAPH_H_
38 #define TENSORFLOW_GRAPH_GRAPH_H_
39 
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/edgeset.h"
#include "tensorflow/core/lib/core/arena.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
54 
55 namespace tensorflow {
56 
// Forward declarations of classes defined later in this header.
class Edge;
class Graph;
class NeighborIter;
class Node;
class NodeIter;

// Forward declarations of classes defined outside this header.
class EdgeSetTest;
class GraphDef;
class NodeProperties;  // Defined in the .cc file.
class VersionDef;
class WhileContext;

69 class Node {
70  public:
71   string DebugString() const;
id()72   int id() const { return id_; }
cost_id()73   int cost_id() const { return cost_id_; }
74   const string& name() const;
75   const string& type_string() const;
76 
77   // def() provides the NodeDef the user supplied, but the specifics
78   // of this Node may have changed due to placement, optimization, etc.
79   // In particular:
80   // * def().name() will match name();
81   // * def().op() will match type_string() and op_def().name();
82   // * def().input() is not reliable, use "in_edges()" below instead;
83   // * def().device() is the "user's requested device" and may not match
84   //   the actual assigned device, see assigned_device_name() below;
85   // * def().attr() is authoritative.
86   // TODO(irving): Replace with NodeInfo.
87   const NodeDef& def() const;
88   const OpDef& op_def() const;
89 
90   // input and output types
91   int32 num_inputs() const;
92   DataType input_type(int32 i) const;
93   const DataTypeVector& input_types() const;
94 
95   int32 num_outputs() const;
96   DataType output_type(int32 o) const;
97   const DataTypeVector& output_types() const;
98 
99   // The device requested by the user.  For the actual assigned device,
100   // use assigned_device_name() below.
101   const string& requested_device() const;
102 
103   // This changes the user requested device but not necessarily the device that
104   // on which the operation will run.
105   void set_requested_device(const string& device);
106 
107   // This gives the device the runtime has assigned this node to.  If
108   // you want the device the user requested, use def().device() instead.
109   // TODO(josh11b): Validate that the assigned_device, if not empty:
110   // fully specifies a device, and satisfies def().device().
111   // TODO(josh11b): Move assigned_device_name outside of Node into a
112   // NodeId->DeviceName map.
113   const string& assigned_device_name() const;
114   void set_assigned_device_name(const string& device_name);
has_assigned_device_name()115   bool has_assigned_device_name() const {
116     return assigned_device_name_index_ > 0;
117   }
assigned_device_name_index()118   int assigned_device_name_index() const { return assigned_device_name_index_; }
119   void set_assigned_device_name_index(int index);
120 
121   // Read only access to attributes
122   AttrSlice attrs() const;
123 
124   // Inputs requested by the NodeDef.  For the actual inputs, use in_edges.
125   const protobuf::RepeatedPtrField<string>& requested_inputs() const;
126 
127   // Get the neighboring nodes via edges either in or out of this node.
128   gtl::iterator_range<NeighborIter> in_nodes() const;
129   gtl::iterator_range<NeighborIter> out_nodes() const;
in_edges()130   const EdgeSet& in_edges() const { return in_edges_; }
out_edges()131   const EdgeSet& out_edges() const { return out_edges_; }
132 
133   // Node type helpers.
IsSource()134   bool IsSource() const { return id() == 0; }
IsSink()135   bool IsSink() const { return id() == 1; }
136   // Anything other than the special Source & Sink nodes.
IsOp()137   bool IsOp() const { return id() > 1; }
138 
139   // Node class helpers
IsSwitch()140   bool IsSwitch() const { return class_ == NC_SWITCH; }
IsMerge()141   bool IsMerge() const { return class_ == NC_MERGE; }
IsEnter()142   bool IsEnter() const { return class_ == NC_ENTER; }
IsExit()143   bool IsExit() const { return class_ == NC_EXIT; }
IsNextIteration()144   bool IsNextIteration() const { return class_ == NC_NEXT_ITERATION; }
IsLoopCond()145   bool IsLoopCond() const { return class_ == NC_LOOP_COND; }
IsControlTrigger()146   bool IsControlTrigger() const { return class_ == NC_CONTROL_TRIGGER; }
IsSend()147   bool IsSend() const { return class_ == NC_SEND || class_ == NC_HOST_SEND; }
IsRecv()148   bool IsRecv() const { return class_ == NC_RECV || class_ == NC_HOST_RECV; }
IsConstant()149   bool IsConstant() const { return class_ == NC_CONSTANT; }
IsVariable()150   bool IsVariable() const { return class_ == NC_VARIABLE; }
IsIdentity()151   bool IsIdentity() const { return class_ == NC_IDENTITY; }
IsGetSessionHandle()152   bool IsGetSessionHandle() const { return class_ == NC_GET_SESSION_HANDLE; }
IsGetSessionTensor()153   bool IsGetSessionTensor() const { return class_ == NC_GET_SESSION_TENSOR; }
IsDeleteSessionTensor()154   bool IsDeleteSessionTensor() const {
155     return class_ == NC_DELETE_SESSION_TENSOR;
156   }
IsControlFlow()157   bool IsControlFlow() const {
158     return (class_ != NC_OTHER) &&  // Fast path
159            (IsSwitch() || IsMerge() || IsEnter() || IsExit() ||
160             IsNextIteration());
161   }
IsHostSend()162   bool IsHostSend() const { return class_ == NC_HOST_SEND; }
IsHostRecv()163   bool IsHostRecv() const { return class_ == NC_HOST_RECV; }
164 
IsMetadata()165   bool IsMetadata() const { return class_ == NC_METADATA; }
166 
167   template <typename T>
AddAttr(const string & name,const T & val)168   void AddAttr(const string& name, const T& val) {
169     SetAttrValue(val, AddAttrHelper(name));
170   }
171 
172   void ClearAttr(const string& name);
173 
174   // Returns into '*e' the edge connecting to the 'idx' input of this Node.
175   Status input_edge(int idx, const Edge** e) const;
176 
177   // Returns into '*edges' the input data edges of this Node, indexed by input
178   // number. Does not return control edges.
179   Status input_edges(std::vector<const Edge*>* edges) const;
180 
181   // Returns into '*n' the node that has an output connected to the
182   // 'idx' input of this Node.
183   Status input_node(int idx, const Node** n) const;
184   Status input_node(int idx, Node** n) const;
185 
while_ctx()186   WhileContext* while_ctx() const { return while_ctx_; }
set_while_ctx(WhileContext * while_ctx)187   void set_while_ctx(WhileContext* while_ctx) {
188     DCHECK(IsExit());
189     DCHECK(while_ctx_ == nullptr);
190     while_ctx_ = while_ctx;
191   }
192 
193  private:
194   friend class Graph;
195   Node();
196 
properties()197   NodeProperties* properties() const { return props_.get(); }
198 
199   void Initialize(int id, int cost_id, std::shared_ptr<NodeProperties> props);
200 
201   // Releases memory from props_, in addition to restoring *this to its
202   // uninitialized state.
203   void Clear();
204 
205   // Make a copy of the Node's props_ if props_ is shared with
206   // other nodes. This must be called before mutating properties,
207   // e.g. in AddAttr.
208   void MaybeCopyOnWrite();
209 
210   AttrValue* AddAttrHelper(const string& name);
211 
212   // A set of mutually exclusive classes for different kinds of nodes,
213   // class_ is initialized in the Node::Initialize routine based on the
214   // node's type_string().
215   enum NodeClass {
216     NC_UNINITIALIZED,
217     NC_SWITCH,
218     NC_MERGE,
219     NC_ENTER,
220     NC_EXIT,
221     NC_NEXT_ITERATION,
222     NC_LOOP_COND,
223     NC_CONTROL_TRIGGER,
224     NC_SEND,
225     NC_HOST_SEND,
226     NC_RECV,
227     NC_HOST_RECV,
228     NC_CONSTANT,
229     NC_VARIABLE,
230     NC_IDENTITY,
231     NC_GET_SESSION_HANDLE,
232     NC_GET_SESSION_TENSOR,
233     NC_DELETE_SESSION_TENSOR,
234     NC_METADATA,
235     NC_OTHER  // Not a special kind of node
236   };
237 
238   static const std::unordered_map<string, NodeClass>& kNodeClassTable;
239 
240   static NodeClass GetNodeClassForOp(const string& ts);
241 
242   int id_;       // -1 until Initialize() is called
243   int cost_id_;  // -1 if there is no corresponding cost accounting node
244   NodeClass class_;
245 
246   EdgeSet in_edges_;
247   EdgeSet out_edges_;
248 
249   // NOTE(skyewm): inheriting from core::RefCounted may have a slight
250   // performance benefit over using shared_ptr, at the cost of manual ref
251   // counting
252   std::shared_ptr<NodeProperties> props_;
253 
254   // Index within Graph::device_names_ of the name of device assigned
255   // to perform this computation.
256   int assigned_device_name_index_;
257 
258   // A back-pointer to the Graph that owns this node.  Currently, this exists
259   // solely to allow Node::[set_]assigned_device_name() to work. However, if all
260   // callers of Node::[set_]assigned_device_name() are modified to use the
261   // equivalent methods defined directly on Graph, then we can remove this
262   // field and reclaim that memory.
263   Graph* graph_;
264 
265   // Set if this is an exit node of a while loop with an associated
266   // WhileContext. Otherwise null. (This is only set for exit nodes because
267   // they're the first nodes of a loop encountered while creating the gradient
268   // graph. Exit nodes that are part of while loop gradient graphs will not have
269   // this set.)
270   WhileContext* while_ctx_;
271 
272   TF_DISALLOW_COPY_AND_ASSIGN(Node);
273 };
274 
275 // Represents an input of a node, i.e., the `index`-th input to `node`.
276 struct InputTensor {
277   const Node* node;
278   int index;
279 
InputTensorInputTensor280   InputTensor(const Node* n, int i) : node(n), index(i) {}
InputTensorInputTensor281   InputTensor() : node(nullptr), index(0) {}
282 };
283 
284 // Represents an output of a node, i.e., the `index`-th output of `node`. Note
285 // that a single `OutputTensor` can correspond to multiple `Edge`s if the output
286 // is consumed by multiple destination nodes.
287 struct OutputTensor {
288   const Node* node;
289   int index;
290 
OutputTensorOutputTensor291   OutputTensor(const Node* n, int i) : node(n), index(i) {}
OutputTensorOutputTensor292   OutputTensor() : node(nullptr), index(0) {}
293 };
294 
295 class Edge {
296  public:
src()297   Node* src() const { return src_; }
dst()298   Node* dst() const { return dst_; }
id()299   int id() const { return id_; }
300 
301   // Return the index of the source output that produces the data
302   // carried by this edge.  The special value kControlSlot is used
303   // for control dependencies.
src_output()304   int src_output() const { return src_output_; }
305 
306   // Return the index of the destination input that consumes the data
307   // carried by this edge.  The special value kControlSlot is used
308   // for control dependencies.
dst_input()309   int dst_input() const { return dst_input_; }
310 
311   // Return true iff this is an edge that indicates a control-flow
312   // (as opposed to a data-flow) dependency.
313   bool IsControlEdge() const;
314 
315   string DebugString() const;
316 
317  private:
Edge()318   Edge() {}
319 
320   friend class EdgeSetTest;
321   friend class Graph;
322   Node* src_;
323   Node* dst_;
324   int id_;
325   int src_output_;
326   int dst_input_;
327 };
328 
329 // Allows for iteration of the edges of a Graph, by iterating the underlying
330 // Graph.edges_ vector while skipping over null entries.
331 class GraphEdgesIterable {
332  private:
333   const std::vector<Edge*>& edges_;
334 
335  public:
GraphEdgesIterable(const std::vector<Edge * > & edges)336   explicit GraphEdgesIterable(const std::vector<Edge*>& edges)
337       : edges_(edges) {}
338 
339   typedef Edge* value_type;
340 
341   class const_iterator {
342    private:
343     // The underlying iterator.
344     std::vector<value_type>::const_iterator iter_;
345 
346     // The end of the underlying iterator.
347     std::vector<value_type>::const_iterator end_;
348 
349     // Advances iter_ until it reaches a non-null item, or reaches the end.
apply_filter()350     void apply_filter() {
351       while (iter_ != end_ && *iter_ == nullptr) {
352         ++iter_;
353       }
354     }
355 
356    public:
const_iterator(std::vector<value_type>::const_iterator iter,std::vector<value_type>::const_iterator end)357     const_iterator(std::vector<value_type>::const_iterator iter,
358                    std::vector<value_type>::const_iterator end)
359         : iter_(iter), end_(end) {
360       apply_filter();
361     }
362 
363     bool operator==(const const_iterator& other) const {
364       return iter_ == other.iter_;
365     }
366 
367     bool operator!=(const const_iterator& other) const {
368       return iter_ != other.iter_;
369     }
370 
371     // This is the prefix increment operator (++x), which is the operator
372     // used by C++ range iteration (for (x : y) ...).  We intentionally do not
373     // provide a postfix increment operator.
374     const_iterator& operator++() {
375       ++iter_;
376       apply_filter();
377       return *this;
378     }
379 
380     value_type operator*() { return *iter_; }
381   };
382 
begin()383   const_iterator begin() {
384     return const_iterator(edges_.begin(), edges_.end());
385   }
end()386   const_iterator end() { return const_iterator(edges_.end(), edges_.end()); }
387 };
388 
389 // Thread compatible but not thread safe.
390 class Graph {
391  public:
392   // Constructs a graph with a single SOURCE (always id kSourceId) and a
393   // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK.
394   //
395   // The graph can hold ops found in registry. `registry`s lifetime must be at
396   // least that of the constructed graph's.
397   explicit Graph(const OpRegistryInterface* registry);
398 
399   // Constructs a graph with a single SOURCE (always id kSourceId) and a
400   // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK.
401   //
402   // The graph can hold ops found in `flib_def`. Unlike the constructor taking
403   // an OpRegistryInterface, this constructor copies the function definitions in
404   // `flib_def` so its lifetime may be shorter than that of the graph's. The
405   // OpRegistryInterface backing `flib_def` must still have the lifetime of the
406   // graph though.
407   explicit Graph(const FunctionLibraryDefinition& flib_def);
408 
409   ~Graph();
410 
411   static const int kControlSlot;
412 
413   // The GraphDef version range of this graph (see graph.proto).
414   const VersionDef& versions() const;
415   void set_versions(const VersionDef& versions);
416 
417   // Adds a new node to this graph, and returns it. Infers the Op and
418   // input/output types for the node. *this owns the returned instance.
419   // Returns nullptr and sets *status on error.
420   Node* AddNode(const NodeDef& node_def, Status* status);
421 
422   // Copies *node, which may belong to another graph, to a new node,
423   // which is returned.  Does not copy any edges.  *this owns the
424   // returned instance.
425   Node* CopyNode(Node* node);
426 
427   // Removes a node from this graph, including all edges from or to it.
428   // *node should not be accessed after calling this function.
429   // REQUIRES: node->IsOp()
430   void RemoveNode(Node* node);
431 
432   // Adds an edge that connects the xth output of `source` to the yth input of
433   // `dest` and returns it. Does not update dest's NodeDef.
434   const Edge* AddEdge(Node* source, int x, Node* dest, int y);
435 
436   // Adds a control edge (no data flows along this edge) that connects `source`
437   // to `dest`. If `dest`s NodeDef is missing the corresponding control input,
438   // adds the control input.
439   //
440   // If such a control edge already exists and `allow_duplicates` is false, no
441   // edge is added and the function returns nullptr. Otherwise the edge is
442   // unconditionally created and returned. The NodeDef is not updated if
443   // `allow_duplicates` is true.
444   // TODO(skyewm): // TODO(skyewm): allow_duplicates is needed only by
445   // graph_partition.cc. Figure out if we can do away with it.
446   const Edge* AddControlEdge(Node* source, Node* dest,
447                              bool allow_duplicates = false);
448 
449   // Removes edge from the graph. Does not update the destination node's
450   // NodeDef.
451   // REQUIRES: The edge must exist.
452   void RemoveEdge(const Edge* edge);
453 
454   // Removes control edge `edge` from the graph. Note that this also updates
455   // the corresponding NodeDef to reflect the change.
456   // REQUIRES: The control edge must exist.
457   void RemoveControlEdge(const Edge* e);
458   // Updates the input to a node.  The existing edge to `dst` is removed and an
459   // edge from `new_src` to `dst` is created. The NodeDef associated with `dst`
460   // is also updated.
461   Status UpdateEdge(Node* new_src, int new_src_index, Node* dst, int dst_index);
462 
463   // Adds the function and gradient definitions in `fdef_lib` to this graph's op
464   // registry. Ignores duplicate functions, and returns a bad status if an
465   // imported function differs from an existing function or op with the same
466   // name.
467   Status AddFunctionLibrary(const FunctionDefLibrary& fdef_lib);
468 
469   // The number of live nodes in the graph.
470   //
471   // Because nodes can be removed from the graph, num_nodes() is often
472   // smaller than num_node_ids(). If one needs to create an array of
473   // nodes indexed by node ids, num_node_ids() should be used as the
474   // array's size.
num_nodes()475   int num_nodes() const { return num_nodes_; }
476 
477   // The number of live nodes in the graph, excluding the Source and Sink nodes.
num_op_nodes()478   int num_op_nodes() const {
479     DCHECK_GE(num_nodes_, 2);
480     return num_nodes_ - 2;
481   }
482 
483   // The number of live edges in the graph.
484   //
485   // Because edges can be removed from the graph, num_edges() is often
486   // smaller than num_edge_ids(). If one needs to create an array of
487   // edges indexed by edge ids, num_edge_ids() should be used as the
488   // array's size.
num_edges()489   int num_edges() const { return num_edges_; }
490 
491   // Serialize the nodes starting at `from_node_id` to a GraphDef.
492   void ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const;
493 
494   // Serialize to a GraphDef.
495   void ToGraphDef(GraphDef* graph_def) const;
496 
497   // This version can be called from debugger to inspect the graph content.
498   // Use the previous version outside debug context for efficiency reasons.
499   //
500   // Note: We do not expose a DebugString() API, since GraphDef.DebugString() is
501   // not defined in some TensorFlow builds.
502   GraphDef ToGraphDefDebug() const;
503 
504   // Generate new node name with the specified prefix that is unique
505   // across this graph.
506   string NewName(StringPiece prefix);
507 
508   // Access to the list of all nodes.  Example usage:
509   //   for (Node* node : graph.nodes()) { ... }
510   gtl::iterator_range<NodeIter> nodes() const;
511 
512   // Access to the list of all nodes, excluding the Source and Sink nodes.
513   gtl::iterator_range<NodeIter> op_nodes() const;
514 
515   // Returns one more than the maximum id assigned to any node.
num_node_ids()516   int num_node_ids() const { return nodes_.size(); }
517 
518   // Returns the node associated with an id, or nullptr if no node
519   // with that id (the node with that id was removed and the id has
520   // not yet been re-used). *this owns the returned instance.
521   // REQUIRES: 0 <= id < num_node_ids().
FindNodeId(int id)522   Node* FindNodeId(int id) const { return nodes_[id]; }
523 
524   // Returns one more than the maximum id assigned to any edge.
num_edge_ids()525   int num_edge_ids() const { return edges_.size(); }
526 
527   // Returns the Edge associated with an id, or nullptr if no edge
528   // with that id (the node with that id was removed and the id has
529   // not yet been re-used). *this owns the returned instance.
530   // REQUIRES: 0 <= id < num_node_ids().
FindEdgeId(int id)531   const Edge* FindEdgeId(int id) const { return edges_[id]; }
532 
533   // Access to the set of all edges.  Example usage:
534   //   for (const Edge* e : graph.edges()) { ... }
edges()535   GraphEdgesIterable edges() const { return GraphEdgesIterable(edges_); }
536 
537   // The pre-defined nodes.
538   enum { kSourceId = 0, kSinkId = 1 };
source_node()539   Node* source_node() const { return FindNodeId(kSourceId); }
sink_node()540   Node* sink_node() const { return FindNodeId(kSinkId); }
541 
op_registry()542   const OpRegistryInterface* op_registry() const { return &ops_; }
flib_def()543   const FunctionLibraryDefinition& flib_def() const { return ops_; }
544 
CheckDeviceNameIndex(int index)545   void CheckDeviceNameIndex(int index) {
546     DCHECK_GE(index, 0);
547     DCHECK_LT(index, static_cast<int>(device_names_.size()));
548   }
549 
550   int InternDeviceName(const string& device_name);
551 
get_assigned_device_name(const Node & node)552   const string& get_assigned_device_name(const Node& node) const {
553     return device_names_[node.assigned_device_name_index()];
554   }
555 
set_assigned_device_name_index(Node * node,int device_name_index)556   void set_assigned_device_name_index(Node* node, int device_name_index) {
557     CheckDeviceNameIndex(device_name_index);
558     node->assigned_device_name_index_ = device_name_index;
559   }
560 
set_assigned_device_name(Node * node,const string & device_name)561   void set_assigned_device_name(Node* node, const string& device_name) {
562     node->assigned_device_name_index_ = InternDeviceName(device_name);
563   }
564 
565   // Returns OK if `node` is non-null and belongs to this graph
566   Status IsValidNode(const Node* node) const;
567 
568   // Returns OK if IsValidNode(`node`) and `idx` is less than
569   // node->num_outputs()
570   Status IsValidOutputTensor(const Node* node, int idx) const;
571 
572   // Returns OK if IsValidNode(`node`) and `idx` is less than
573   // node->num_inputs()
574   Status IsValidInputTensor(const Node* node, int idx) const;
575 
576   // Create and return a new WhileContext owned by this graph. This is called
577   // when a new while loop is created. `frame_name` must be unique among
578   // WhileContexts in this graph.
579   Status AddWhileContext(StringPiece frame_name, std::vector<Node*> enter_nodes,
580                          std::vector<Node*> exit_nodes,
581                          OutputTensor cond_output,
582                          std::vector<OutputTensor> body_inputs,
583                          std::vector<OutputTensor> body_outputs,
584                          WhileContext** result);
585 
586   // TODO(josh11b): uint64 hash() const;
587 
588  private:
589   // If cost_node is non-null, then cost accounting (in CostModel)
590   // will be associated with that node rather than the new one being
591   // created.
592   //
593   // Ownership of the returned Node is not transferred to caller.
594   Node* AllocateNode(std::shared_ptr<NodeProperties> props,
595                      const Node* cost_node);
596   void ReleaseNode(Node* node);
597 
598   // Registry of all known ops, including functions.
599   FunctionLibraryDefinition ops_;
600 
601   // GraphDef versions
602   const std::unique_ptr<VersionDef> versions_;
603 
604   // Allocator which will give us good locality.
605   core::Arena arena_;
606 
607   // Map from node ids to allocated nodes.  nodes_[id] may be nullptr if
608   // the node with that id was removed from the graph.
609   std::vector<Node*> nodes_;
610 
611   // Number of nodes alive.
612   int64 num_nodes_ = 0;
613 
614   // Map from edge ids to allocated edges.  edges_[id] may be nullptr if
615   // the edge with that id was removed from the graph.
616   std::vector<Edge*> edges_;
617 
618   // The number of entries in edges_ that are not nullptr.
619   int num_edges_ = 0;
620 
621   // Allocated but free nodes and edges.
622   std::vector<Node*> free_nodes_;
623   std::vector<Edge*> free_edges_;
624 
625   // For generating unique names.
626   int name_counter_ = 0;
627 
628   // In most graphs, the number of unique values used for the
629   // Node::assigned_device_name() property is quite small.  If the graph is
630   // large, then this duplication of values can consume a significant amount of
631   // memory.  Instead, we represent the same information using an interning
632   // table, which consists of a vector of unique strings (device_names_), as
633   // well a map (device_names_map_) from unique strings to indices within the
634   // unique string table.
635   //
636   // The InternDeviceName() method handles adding a new entry into the table,
637   // or locating the index of an existing entry.
638   //
639   // The fact that Node::assigned_device_name() is implemented using an
640   // interning table is intentionally public.  This allows algorithms that
641   // frequently access this field to do so efficiently, especially for the case
642   // where the assigned_device_name of one Node is copied directly from that
643   // of another Node.
644 
645   // A table of the unique assigned device names.  Indices do NOT correspond
646   // to node IDs.  Index 0 is always the empty string.
647   std::vector<string> device_names_;
648 
649   // Maps unique device names to indices within device_names_[i].
650   std::unordered_map<string, int> device_names_map_;
651 
652   // All the while contexts owned by this graph, keyed by frame name,
653   // corresponding to all the while loops contained in this graph (including
654   // nested loops). The stored contexts are usually accessed via
655   // AddWhileContext() or Node::while_ctx(), but this manages the lifetime.
656   std::map<string, WhileContext> while_ctxs_;
657 
658   // Searches through edges_ for the Edge whose destination node and index
659   // matches dst. An edge with destination `dst` must exist in the graph.
660   const Edge* FindEdge(const Node* dst, int index);
661 
662   TF_DISALLOW_COPY_AND_ASSIGN(Graph);
663 };
664 
665 // TODO(josh11b): We may want to support keeping an index on various
666 // node/edge attributes in a graph, particularly node names.
667 
668 // Helper routines
669 
IsSource(const Node * node)670 inline bool IsSource(const Node* node) { return node->IsSource(); }
IsSink(const Node * node)671 inline bool IsSink(const Node* node) { return node->IsSink(); }
IsSwitch(const Node * node)672 inline bool IsSwitch(const Node* node) { return node->IsSwitch(); }
IsMerge(const Node * node)673 inline bool IsMerge(const Node* node) { return node->IsMerge(); }
IsEnter(const Node * node)674 inline bool IsEnter(const Node* node) { return node->IsEnter(); }
IsExit(const Node * node)675 inline bool IsExit(const Node* node) { return node->IsExit(); }
IsNextIteration(const Node * n)676 inline bool IsNextIteration(const Node* n) { return n->IsNextIteration(); }
IsLoopCond(const Node * node)677 inline bool IsLoopCond(const Node* node) { return node->IsLoopCond(); }
IsControlTrigger(const Node * n)678 inline bool IsControlTrigger(const Node* n) { return n->IsControlTrigger(); }
IsSend(const Node * node)679 inline bool IsSend(const Node* node) { return node->IsSend(); }
IsRecv(const Node * node)680 inline bool IsRecv(const Node* node) { return node->IsRecv(); }
IsHostSend(const Node * node)681 inline bool IsHostSend(const Node* node) { return node->IsHostSend(); }
IsHostRecv(const Node * node)682 inline bool IsHostRecv(const Node* node) { return node->IsHostRecv(); }
683 
684 // True for Nodes that mediate the transfer of values between processes.
IsTransferNode(const Node * n)685 inline bool IsTransferNode(const Node* n) { return IsSend(n) || IsRecv(n); }
686 
IsConstant(const Node * node)687 inline bool IsConstant(const Node* node) { return node->IsConstant(); }
IsVariable(const Node * node)688 inline bool IsVariable(const Node* node) { return node->IsVariable(); }
IsIdentity(const Node * node)689 inline bool IsIdentity(const Node* node) { return node->IsIdentity(); }
690 
691 // Returns true iff 'n' is a control flow node.
IsControlFlow(const Node * n)692 inline bool IsControlFlow(const Node* n) { return n->IsControlFlow(); }
693 
694 // Returns true if the node only depends on its input's metadata
695 // (shape).  Specifically, returns true for "Size", "Shape" and "Rank" ops.
IsMetadata(const Node * n)696 inline bool IsMetadata(const Node* n) { return n->IsMetadata(); }
697 
IsHostMemoryPreserving(const Node * node)698 inline bool IsHostMemoryPreserving(const Node* node) {
699   return IsIdentity(node) || IsControlFlow(node);
700 }
701 
702 // Iterator for stepping through the nodes of a graph.
703 class NodeIter {
704  public:
705   NodeIter(const Graph* graph, int id);
706   bool operator==(const NodeIter& rhs);
707   bool operator!=(const NodeIter& rhs);
708   void operator++();
709   Node* operator*();
710   Node* operator->();
711 
712  private:
713   // Invariant: id_ == graph_->num_node_ids() || graph_->FindId(id_) != nullptr
714   const Graph* graph_;
715   int id_;
716 };
717 
718 // Iterator for stepping through the neighbors of a node.
719 class NeighborIter {
720  public:
721   NeighborIter(EdgeSet::const_iterator iter, bool incoming);
722   bool operator==(const NeighborIter& rhs);
723   bool operator!=(const NeighborIter& rhs);
724   void operator++();
725   Node* operator*();
726   Node* operator->();
727 
728  private:
729   EdgeSet::const_iterator iter_;
730   bool incoming_;
731 };
732 
733 // IMPLEMENTATION DETAILS, PLEASE IGNORE
734 
NodeIter(const Graph * graph,int id)735 inline NodeIter::NodeIter(const Graph* graph, int id)
736     : graph_(graph), id_(id) {}
737 
738 inline bool NodeIter::operator==(const NodeIter& rhs) {
739   DCHECK(graph_ == rhs.graph_);
740   return id_ == rhs.id_;
741 }
742 
743 inline bool NodeIter::operator!=(const NodeIter& rhs) {
744   return !(*this == rhs);
745 }
746 
747 inline void NodeIter::operator++() {
748   while (1) {
749     DCHECK_LE(id_, graph_->num_node_ids());
750     ++id_;
751     if (id_ >= graph_->num_node_ids() || graph_->FindNodeId(id_) != nullptr) {
752       return;
753     }
754   }
755 }
756 
757 inline Node* NodeIter::operator*() { return graph_->FindNodeId(id_); }
758 
759 inline Node* NodeIter::operator->() { return graph_->FindNodeId(id_); }
760 
NeighborIter(EdgeSet::const_iterator iter,bool incoming)761 inline NeighborIter::NeighborIter(EdgeSet::const_iterator iter, bool incoming)
762     : iter_(iter), incoming_(incoming) {}
763 
764 inline bool NeighborIter::operator==(const NeighborIter& rhs) {
765   return iter_ == rhs.iter_ && incoming_ == rhs.incoming_;
766 }
767 
768 inline bool NeighborIter::operator!=(const NeighborIter& rhs) {
769   return !(*this == rhs);
770 }
771 
772 inline void NeighborIter::operator++() { ++iter_; }
773 
774 inline Node* NeighborIter::operator*() {
775   const Edge* e = *iter_;
776   return incoming_ ? e->src() : e->dst();
777 }
778 
779 inline Node* NeighborIter::operator->() {
780   const Edge* e = *iter_;
781   return incoming_ ? e->src() : e->dst();
782 }
783 
IsControlEdge()784 inline bool Edge::IsControlEdge() const {
785   // Note that if either src_output_ or dst_input_ is kControlSlot,
786   // so is the other one (AddEdge checks this).
787   return src_output_ == Graph::kControlSlot;
788 }
789 
nodes()790 inline gtl::iterator_range<NodeIter> Graph::nodes() const {
791   // Note that NodeId 0 is always valid since we don't let the source
792   // node be removed from the graph.
793   return gtl::make_range(NodeIter(this, 0), NodeIter(this, num_node_ids()));
794 }
795 
op_nodes()796 inline gtl::iterator_range<NodeIter> Graph::op_nodes() const {
797   // Note that NodeId 0 is always valid since we don't let the source
798   // node be removed from the graph.
799   //
800   // The current implementation of Graph maintains the invariant that the
801   // first two nodes are the source and sink nodes, and all other nodes are op
802   // nodes. This method (op_nodes()) relies on this invariant.
803   NodeIter begin(this, 0);
804   NodeIter end(this, num_node_ids());
805   if (begin != end) {
806     ++begin;
807   }
808   if (begin != end) {
809     ++begin;
810   }
811   return gtl::make_range(begin, end);
812 }
813 
set_assigned_device_name_index(int index)814 inline void Node::set_assigned_device_name_index(int index) {
815   graph_->CheckDeviceNameIndex(index);
816   assigned_device_name_index_ = index;
817 }
818 
set_assigned_device_name(const string & device_name)819 inline void Node::set_assigned_device_name(const string& device_name) {
820   graph_->set_assigned_device_name(this, device_name);
821 }
822 
assigned_device_name()823 inline const string& Node::assigned_device_name() const {
824   return graph_->get_assigned_device_name(*this);
825 }
826 
827 }  // namespace tensorflow
828 
829 #endif  // TENSORFLOW_GRAPH_GRAPH_H_
830