• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// A Graph describes a set of computations that are to be
// performed, as well as the dependencies between those
// computations. The basic model is a DAG (directed acyclic graph) with
// * internal nodes representing computational operations to be performed;
// * edges representing dependencies, indicating the target may only be
//   executed once the source has completed; and
// * predefined "source" (start) and "sink" (finish) nodes -- the source
//   should be the only node that doesn't depend on anything, and the sink
//   should be the only node that nothing depends on.
//
// Note: Node ids are intended to be relatively dense in the
// 0..max_id range, but there may be gaps since ids won't be reused.
//
// Note: Some dependencies between operations are due to one operation
// consuming the output of another. In fact, operations can produce
// multiple outputs and consume multiple inputs, and some
// optimizations will care about which specific outputs are connected
// to which specific inputs.  We therefore represent the data dependency
// between output O of operation A and input I of operation B using
// "input index" and "output index" labels per edge.
36 
37 #ifndef TENSORFLOW_CORE_GRAPH_GRAPH_H_
38 #define TENSORFLOW_CORE_GRAPH_GRAPH_H_
39 
40 #include <functional>
41 #include <string>
42 #include <vector>
43 
44 #include "tensorflow/core/framework/function.h"
45 #include "tensorflow/core/framework/node_def.pb.h"
46 #include "tensorflow/core/framework/op.h"
47 #include "tensorflow/core/framework/types.h"
48 #include "tensorflow/core/graph/edgeset.h"
49 #include "tensorflow/core/lib/core/arena.h"
50 #include "tensorflow/core/lib/core/refcount.h"
51 #include "tensorflow/core/lib/core/status.h"
52 #include "tensorflow/core/lib/gtl/iterator_range.h"
53 #include "tensorflow/core/platform/logging.h"
54 #include "tensorflow/core/platform/macros.h"
55 #include "tensorflow/core/platform/types.h"
56 
57 namespace tensorflow {
58 
// Forward declarations of types referenced by pointer/reference below.
class Edge;
class EdgeSetTest;
class Graph;
class GraphDef;
class Node;
struct OutputTensor;
class VersionDef;
class WhileContext;

class NeighborIter;     // Declared below
class NodeIter;         // Declared below
struct NodeProperties;  // Defined in .cc
71 
72 class Node {
73  public:
74   string DebugString() const;
id()75   int id() const { return id_; }
cost_id()76   int cost_id() const { return cost_id_; }
77   const string& name() const;
78   void set_name(string name);
79   const string& type_string() const;
80 
81   // def() provides the NodeDef the user supplied, but the specifics
82   // of this Node may have changed due to placement, optimization, etc.
83   // In particular:
84   // * def().name() will match name();
85   // * def().op() will match type_string() and op_def().name();
86   // * def().input() is not reliable, use "in_edges()" below instead;
87   // * def().device() is the "user's requested device" and may not match
88   //   the actual assigned device, see assigned_device_name() below;
89   // * def().attr() is authoritative.
90   // TODO(irving): Replace with NodeInfo.
91   const NodeDef& def() const;
92   const OpDef& op_def() const;
93 
94   // input and output types
95   int32 num_inputs() const;
96   DataType input_type(int32 i) const;
97   const DataTypeVector& input_types() const;
98 
99   int32 num_outputs() const;
100   DataType output_type(int32 o) const;
101   const DataTypeVector& output_types() const;
102 
103   // The device requested by the user.  For the actual assigned device,
104   // use assigned_device_name() below.
105   const string& requested_device() const;
106 
107   // This changes the user requested device but not necessarily the device that
108   // on which the operation will run.
109   void set_requested_device(const string& device);
110 
111   // This gives the device the runtime has assigned this node to.  If
112   // you want the device the user requested, use def().device() instead.
113   // TODO(josh11b): Validate that the assigned_device, if not empty:
114   // fully specifies a device, and satisfies def().device().
115   // TODO(josh11b): Move assigned_device_name outside of Node into a
116   // NodeId->DeviceName map.
117   const string& assigned_device_name() const;
118   void set_assigned_device_name(const string& device_name);
has_assigned_device_name()119   bool has_assigned_device_name() const {
120     return assigned_device_name_index_ > 0;
121   }
assigned_device_name_index()122   int assigned_device_name_index() const { return assigned_device_name_index_; }
123   void set_assigned_device_name_index(int index);
124 
125   // Sets 'original_node_names' field of this node's DebugInfo proto to
126   // 'names'.
127   void set_original_node_names(const std::vector<string>& names);
128 
129   // Read only access to attributes
130   AttrSlice attrs() const;
131 
132   // Inputs requested by the NodeDef.  For the actual inputs, use in_edges.
133   const protobuf::RepeatedPtrField<string>& requested_inputs() const;
134 
135   // Get the neighboring nodes via edges either in or out of this node.  This
136   // includes control edges.
137   gtl::iterator_range<NeighborIter> in_nodes() const;
138   gtl::iterator_range<NeighborIter> out_nodes() const;
in_edges()139   const EdgeSet& in_edges() const { return in_edges_; }
out_edges()140   const EdgeSet& out_edges() const { return out_edges_; }
141 
142   // Node type helpers.
IsSource()143   bool IsSource() const { return id() == 0; }
IsSink()144   bool IsSink() const { return id() == 1; }
145   // Anything other than the special Source & Sink nodes.
IsOp()146   bool IsOp() const { return id() > 1; }
147 
148   // Node class helpers
IsSwitch()149   bool IsSwitch() const { return class_ == NC_SWITCH; }
IsMerge()150   bool IsMerge() const { return class_ == NC_MERGE; }
IsEnter()151   bool IsEnter() const { return class_ == NC_ENTER; }
IsExit()152   bool IsExit() const { return class_ == NC_EXIT; }
IsNextIteration()153   bool IsNextIteration() const { return class_ == NC_NEXT_ITERATION; }
IsLoopCond()154   bool IsLoopCond() const { return class_ == NC_LOOP_COND; }
IsControlTrigger()155   bool IsControlTrigger() const { return class_ == NC_CONTROL_TRIGGER; }
IsSend()156   bool IsSend() const { return class_ == NC_SEND || class_ == NC_HOST_SEND; }
IsRecv()157   bool IsRecv() const { return class_ == NC_RECV || class_ == NC_HOST_RECV; }
IsConstant()158   bool IsConstant() const { return class_ == NC_CONSTANT; }
IsVariable()159   bool IsVariable() const { return class_ == NC_VARIABLE; }
IsIdentity()160   bool IsIdentity() const { return class_ == NC_IDENTITY; }
IsGetSessionHandle()161   bool IsGetSessionHandle() const { return class_ == NC_GET_SESSION_HANDLE; }
IsGetSessionTensor()162   bool IsGetSessionTensor() const { return class_ == NC_GET_SESSION_TENSOR; }
IsDeleteSessionTensor()163   bool IsDeleteSessionTensor() const {
164     return class_ == NC_DELETE_SESSION_TENSOR;
165   }
IsControlFlow()166   bool IsControlFlow() const {
167     return (class_ != NC_OTHER) &&  // Fast path
168            (IsSwitch() || IsMerge() || IsEnter() || IsExit() ||
169             IsNextIteration());
170   }
IsHostSend()171   bool IsHostSend() const { return class_ == NC_HOST_SEND; }
IsHostRecv()172   bool IsHostRecv() const { return class_ == NC_HOST_RECV; }
IsScopedAllocator()173   bool IsScopedAllocator() const { return class_ == NC_SCOPED_ALLOCATOR; }
IsCollective()174   bool IsCollective() const { return class_ == NC_COLLECTIVE; }
175 
IsMetadata()176   bool IsMetadata() const { return class_ == NC_METADATA; }
IsFakeParam()177   bool IsFakeParam() const { return class_ == NC_FAKE_PARAM; }
IsPartitionedCall()178   bool IsPartitionedCall() const { return class_ == NC_PARTITIONED_CALL; }
179 
180   // Returns true if this node is any kind of function call node.
181   //
182   // NOTE: "function call nodes" include partitioned call ops, symbolic gradient
183   // ops, and ops whose type_string is the name of a function ("function ops").
IsFunctionCall()184   bool IsFunctionCall() const {
185     return class_ == NC_PARTITIONED_CALL || class_ == NC_FUNCTION_OP ||
186            class_ == NC_SYMBOLIC_GRADIENT;
187   }
188 
IsIfNode()189   bool IsIfNode() const { return class_ == NC_IF; }
IsWhileNode()190   bool IsWhileNode() const { return class_ == NC_WHILE; }
191   // Is this node a function input
IsArg()192   bool IsArg() const { return class_ == NC_ARG; }
193   // Is this node a function output
IsRetval()194   bool IsRetval() const { return class_ == NC_RETVAL; }
195 
196   template <typename T>
AddAttr(const string & name,const T & val)197   void AddAttr(const string& name, const T& val) {
198     SetAttrValue(val, AddAttrHelper(name));
199     UpdateProperties();
200   }
201 
AddAttr(const string & name,std::vector<string> && val)202   void AddAttr(const string& name, std::vector<string>&& val) {
203     MoveAttrValue(std::move(val), AddAttrHelper(name));
204     UpdateProperties();
205   }
206 
207   void ClearAttr(const string& name);
208 
209   // Returns into '*e' the edge connecting to the 'idx' input of this Node.
210   Status input_edge(int idx, const Edge** e) const;
211 
212   // Returns into '*edges' the input data edges of this Node, indexed by input
213   // number. Does not return control edges.
214   Status input_edges(std::vector<const Edge*>* edges) const;
215 
216   // Returns into '*n' the node that has an output connected to the
217   // 'idx' input of this Node.
218   Status input_node(int idx, const Node** n) const;
219   Status input_node(int idx, Node** n) const;
220 
221   // Returns into '*t' the idx-th input tensor of this node, represented as the
222   // output tensor of input_node(idx).
223   Status input_tensor(int idx, OutputTensor* t) const;
224 
while_ctx()225   WhileContext* while_ctx() const { return while_ctx_; }
set_while_ctx(WhileContext * while_ctx)226   void set_while_ctx(WhileContext* while_ctx) {
227     DCHECK(IsExit());
228     DCHECK(while_ctx_ == nullptr);
229     while_ctx_ = while_ctx;
230   }
231 
232  private:
233   friend class Graph;
234   Node();
235 
properties()236   NodeProperties* properties() const { return props_.get(); }
237 
238   void Initialize(int id, int cost_id, std::shared_ptr<NodeProperties> props,
239                   bool is_function_op);
240 
241   // Releases memory from props_, in addition to restoring *this to its
242   // uninitialized state.
243   void Clear();
244 
245   // Make a copy of the Node's props_ if props_ is shared with
246   // other nodes. This must be called before mutating properties,
247   // e.g. in AddAttr.
248   void MaybeCopyOnWrite();
249 
250   // Called after an attr has changed. Decides whether we need to update some
251   // property of the node (stored in props_).
252   void UpdateProperties();
253 
254   AttrValue* AddAttrHelper(const string& name);
255 
256   // A set of mutually exclusive classes for different kinds of nodes,
257   // class_ is initialized in the Node::Initialize routine based on the
258   // node's type_string().
259   enum NodeClass {
260     NC_UNINITIALIZED,
261     NC_SWITCH,
262     NC_MERGE,
263     NC_ENTER,
264     NC_EXIT,
265     NC_NEXT_ITERATION,
266     NC_LOOP_COND,
267     NC_CONTROL_TRIGGER,
268     NC_SEND,
269     NC_HOST_SEND,
270     NC_RECV,
271     NC_HOST_RECV,
272     NC_CONSTANT,
273     NC_VARIABLE,
274     NC_IDENTITY,
275     NC_GET_SESSION_HANDLE,
276     NC_GET_SESSION_TENSOR,
277     NC_DELETE_SESSION_TENSOR,
278     NC_METADATA,
279     NC_SCOPED_ALLOCATOR,
280     NC_COLLECTIVE,
281     NC_FAKE_PARAM,
282     NC_PARTITIONED_CALL,
283     NC_FUNCTION_OP,
284     NC_SYMBOLIC_GRADIENT,
285     NC_IF,
286     NC_WHILE,
287     NC_ARG,
288     NC_RETVAL,
289     NC_OTHER  // Not a special kind of node
290   };
291 
292   static const std::unordered_map<string, NodeClass>& kNodeClassTable;
293 
294   static NodeClass GetNodeClassForOp(const string& ts);
295 
296   int id_;       // -1 until Initialize() is called
297   int cost_id_;  // -1 if there is no corresponding cost accounting node
298   NodeClass class_;
299 
300   EdgeSet in_edges_;
301   EdgeSet out_edges_;
302 
303   // NOTE(skyewm): inheriting from core::RefCounted may have a slight
304   // performance benefit over using shared_ptr, at the cost of manual ref
305   // counting
306   std::shared_ptr<NodeProperties> props_;
307 
308   // Index within Graph::device_names_ of the name of device assigned
309   // to perform this computation.
310   int assigned_device_name_index_;
311 
312   // A back-pointer to the Graph that owns this node.  Currently, this exists
313   // solely to allow Node::[set_]assigned_device_name() to work. However, if all
314   // callers of Node::[set_]assigned_device_name() are modified to use the
315   // equivalent methods defined directly on Graph, then we can remove this
316   // field and reclaim that memory.
317   Graph* graph_;
318 
319   // Set if this is an exit node of a while loop with an associated
320   // WhileContext. Otherwise null. (This is only set for exit nodes because
321   // they're the first nodes of a loop encountered while creating the gradient
322   // graph. Exit nodes that are part of while loop gradient graphs will not have
323   // this set.)
324   WhileContext* while_ctx_;
325 
326   TF_DISALLOW_COPY_AND_ASSIGN(Node);
327 };
328 
329 // Stores debug information associated with the Node.
330 struct NodeDebugInfo {
331   const string name;
332   std::vector<string> original_node_names;
333 
334   NodeDebugInfo(const Node& n);
335   NodeDebugInfo(const NodeDef& ndef);
336   NodeDebugInfo(StringPiece node_name, bool has_experimental_debug_info,
337                 const NodeDef_ExperimentalDebugInfo& experimental_debug_info);
338 };
339 
340 // Represents an input of a node, i.e., the `index`-th input to `node`.
341 struct InputTensor {
342   Node* node;
343   int index;
344 
InputTensorInputTensor345   InputTensor(Node* n, int i) : node(n), index(i) {}
InputTensorInputTensor346   InputTensor() : node(nullptr), index(0) {}
347 
348   // Returns true if this InputTensor is identical to 'other'. Nodes are
349   // compared using pointer equality.
350   bool operator==(const InputTensor& other) const;
351 
352   // A hash function for InputTensors. Nodes are hashed based on their pointer
353   // value.
354   struct Hash {
355     uint64 operator()(InputTensor const& s) const;
356   };
357 };
358 
359 // Represents an output of a node, i.e., the `index`-th output of `node`. Note
360 // that a single `OutputTensor` can correspond to multiple `Edge`s if the output
361 // is consumed by multiple destination nodes.
362 struct OutputTensor {
363   Node* node;
364   int index;
365 
OutputTensorOutputTensor366   OutputTensor(Node* n, int i) : node(n), index(i) {}
OutputTensorOutputTensor367   OutputTensor() : node(nullptr), index(0) {}
368 
369   // Returns true if this OutputTensor is identical to 'other'. Nodes are
370   // compared using pointer equality.
371   bool operator==(const OutputTensor& other) const;
372 
373   // A hash function for OutputTensors. Nodes are hashed based on their pointer
374   // value.
375   struct Hash {
376     uint64 operator()(OutputTensor const& s) const;
377   };
378 };
379 
380 class Edge {
381  public:
src()382   Node* src() const { return src_; }
dst()383   Node* dst() const { return dst_; }
id()384   int id() const { return id_; }
385 
386   // Return the index of the source output that produces the data
387   // carried by this edge.  The special value kControlSlot is used
388   // for control dependencies.
src_output()389   int src_output() const { return src_output_; }
390 
391   // Return the index of the destination input that consumes the data
392   // carried by this edge.  The special value kControlSlot is used
393   // for control dependencies.
dst_input()394   int dst_input() const { return dst_input_; }
395 
396   // Return true iff this is an edge that indicates a control-flow
397   // (as opposed to a data-flow) dependency.
398   bool IsControlEdge() const;
399 
400   string DebugString() const;
401 
402  private:
Edge()403   Edge() {}
404 
405   friend class EdgeSetTest;
406   friend class Graph;
407   Node* src_;
408   Node* dst_;
409   int id_;
410   int src_output_;
411   int dst_input_;
412 };
413 
414 // Allows for iteration of the edges of a Graph, by iterating the underlying
415 // Graph.edges_ vector while skipping over null entries.
416 class GraphEdgesIterable {
417  private:
418   const std::vector<Edge*>& edges_;
419 
420  public:
GraphEdgesIterable(const std::vector<Edge * > & edges)421   explicit GraphEdgesIterable(const std::vector<Edge*>& edges)
422       : edges_(edges) {}
423 
424   typedef Edge* value_type;
425 
426   class const_iterator {
427    private:
428     // The underlying iterator.
429     std::vector<value_type>::const_iterator iter_;
430 
431     // The end of the underlying iterator.
432     std::vector<value_type>::const_iterator end_;
433 
434     // Advances iter_ until it reaches a non-null item, or reaches the end.
apply_filter()435     void apply_filter() {
436       while (iter_ != end_ && *iter_ == nullptr) {
437         ++iter_;
438       }
439     }
440 
441    public:
const_iterator(std::vector<value_type>::const_iterator iter,std::vector<value_type>::const_iterator end)442     const_iterator(std::vector<value_type>::const_iterator iter,
443                    std::vector<value_type>::const_iterator end)
444         : iter_(iter), end_(end) {
445       apply_filter();
446     }
447 
448     bool operator==(const const_iterator& other) const {
449       return iter_ == other.iter_;
450     }
451 
452     bool operator!=(const const_iterator& other) const {
453       return iter_ != other.iter_;
454     }
455 
456     // This is the prefix increment operator (++x), which is the operator
457     // used by C++ range iteration (for (x : y) ...).  We intentionally do not
458     // provide a postfix increment operator.
459     const_iterator& operator++() {
460       ++iter_;
461       apply_filter();
462       return *this;
463     }
464 
465     value_type operator*() { return *iter_; }
466   };
467 
begin()468   const_iterator begin() {
469     return const_iterator(edges_.begin(), edges_.end());
470   }
end()471   const_iterator end() { return const_iterator(edges_.end(), edges_.end()); }
472 };
473 
474 // Thread compatible but not thread safe.
475 class Graph {
476  public:
477   // Constructs a graph with a single SOURCE (always id kSourceId) and a
478   // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK.
479   //
480   // The graph can hold ops found in the registry. `ops`s lifetime must be at
481   // least that of the constructed graph's.
482   explicit Graph(const OpRegistryInterface* ops);
483 
484   // Constructs a graph with a single SOURCE (always id kSourceId) and a
485   // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK.
486   //
487   // The graph can hold ops found in `flib_def`. Unlike the constructor taking
488   // an OpRegistryInterface, this constructor copies the function definitions in
489   // `flib_def` so its lifetime may be shorter than that of the graph's. The
490   // OpRegistryInterface backing `flib_def` must still have the lifetime of the
491   // graph though.
492   explicit Graph(const FunctionLibraryDefinition& flib_def);
493 
494   ~Graph();
495 
496   static const int kControlSlot;
497 
498   // The GraphDef version range of this graph (see graph.proto).
499   const VersionDef& versions() const;
500   void set_versions(const VersionDef& versions);
501 
502   // Adds a new node to this graph, and returns it. Infers the Op and
503   // input/output types for the node. *this owns the returned instance.
504   // Returns nullptr and sets *status on error.
505   Node* AddNode(NodeDef node_def, Status* status);
506 
507   // Copies *node, which may belong to another graph, to a new node,
508   // which is returned.  Does not copy any edges.  *this owns the
509   // returned instance.
510   Node* CopyNode(const Node* node);
511 
512   // Removes a node from this graph, including all edges from or to it.
513   // *node should not be accessed after calling this function.
514   // REQUIRES: node->IsOp()
515   void RemoveNode(Node* node);
516 
517   // Adds an edge that connects the xth output of `source` to the yth input of
518   // `dest` and returns it. Does not update dest's NodeDef.
519   const Edge* AddEdge(Node* source, int x, Node* dest, int y);
520 
521   // Adds a control edge (no data flows along this edge) that connects `source`
522   // to `dest`. If `dest`s NodeDef is missing the corresponding control input,
523   // adds the control input.
524   //
525   // If such a control edge already exists and `allow_duplicates` is false, no
526   // edge is added and the function returns nullptr. Otherwise the edge is
527   // unconditionally created and returned. The NodeDef is not updated if
528   // `allow_duplicates` is true.
529   // TODO(skyewm): // TODO(skyewm): allow_duplicates is needed only by
530   // graph_partition.cc. Figure out if we can do away with it.
531   const Edge* AddControlEdge(Node* source, Node* dest,
532                              bool allow_duplicates = false);
533 
534   // Removes edge from the graph. Does not update the destination node's
535   // NodeDef.
536   // REQUIRES: The edge must exist.
537   void RemoveEdge(const Edge* edge);
538 
539   // Removes control edge `edge` from the graph. Note that this also updates
540   // the corresponding NodeDef to reflect the change.
541   // REQUIRES: The control edge must exist.
542   void RemoveControlEdge(const Edge* e);
543 
544   // Updates the input to a node.  The existing edge to `dst` is removed and an
545   // edge from `new_src` to `dst` is created. The NodeDef associated with `dst`
546   // is also updated.
547   Status UpdateEdge(Node* new_src, int new_src_index, Node* dst, int dst_index);
548 
549   // Like AddEdge but updates dst's NodeDef. Used to add an input edge to a
550   // "While" op during gradient construction, see AddInputWhileHack in
551   // python_api.h for more details.
552   Status AddWhileInputHack(Node* new_src, int new_src_index, Node* dst);
553 
554   // Adds the function and gradient definitions in `fdef_lib` to this graph's op
555   // registry. Ignores duplicate functions, and returns a bad status if an
556   // imported function differs from an existing function or op with the same
557   // name.
558   Status AddFunctionLibrary(const FunctionDefLibrary& fdef_lib);
559 
560   // The number of live nodes in the graph.
561   //
562   // Because nodes can be removed from the graph, num_nodes() is often
563   // smaller than num_node_ids(). If one needs to create an array of
564   // nodes indexed by node ids, num_node_ids() should be used as the
565   // array's size.
num_nodes()566   int num_nodes() const { return num_nodes_; }
567 
568   // The number of live nodes in the graph, excluding the Source and Sink nodes.
num_op_nodes()569   int num_op_nodes() const {
570     DCHECK_GE(num_nodes_, 2);
571     return num_nodes_ - 2;
572   }
573 
574   // The number of live edges in the graph.
575   //
576   // Because edges can be removed from the graph, num_edges() is often
577   // smaller than num_edge_ids(). If one needs to create an array of
578   // edges indexed by edge ids, num_edge_ids() should be used as the
579   // array's size.
num_edges()580   int num_edges() const { return num_edges_; }
581 
582   // Serialize the nodes starting at `from_node_id` to a GraphDef.
583   void ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const;
584 
585   // Serialize to a GraphDef.
586   void ToGraphDef(GraphDef* graph_def) const;
587 
588   // This version can be called from debugger to inspect the graph content.
589   // Use the previous version outside debug context for efficiency reasons.
590   //
591   // Note: We do not expose a DebugString() API, since GraphDef.DebugString() is
592   // not defined in some TensorFlow builds.
593   GraphDef ToGraphDefDebug() const;
594 
595   // Generate new node name with the specified prefix that is unique
596   // across this graph.
597   string NewName(StringPiece prefix);
598 
599   // Access to the list of all nodes.  Example usage:
600   //   for (Node* node : graph.nodes()) { ... }
601   gtl::iterator_range<NodeIter> nodes() const;
602 
603   // Access to the list of all nodes, excluding the Source and Sink nodes.
604   gtl::iterator_range<NodeIter> op_nodes() const;
605 
606   // Returns one more than the maximum id assigned to any node.
num_node_ids()607   int num_node_ids() const { return nodes_.size(); }
608 
609   // Returns the node associated with an id, or nullptr if no node
610   // with that id (the node with that id was removed and the id has
611   // not yet been re-used). *this owns the returned instance.
612   // REQUIRES: 0 <= id < num_node_ids().
FindNodeId(int id)613   Node* FindNodeId(int id) const { return nodes_[id]; }
614 
615   // Returns one more than the maximum id assigned to any edge.
num_edge_ids()616   int num_edge_ids() const { return edges_.size(); }
617 
618   // Returns the Edge associated with an id, or nullptr if no edge
619   // with that id (the node with that id was removed and the id has
620   // not yet been re-used). *this owns the returned instance.
621   // REQUIRES: 0 <= id < num_node_ids().
FindEdgeId(int id)622   const Edge* FindEdgeId(int id) const { return edges_[id]; }
623 
624   // Access to the set of all edges.  Example usage:
625   //   for (const Edge* e : graph.edges()) { ... }
edges()626   GraphEdgesIterable edges() const { return GraphEdgesIterable(edges_); }
627 
628   // The pre-defined nodes.
629   enum { kSourceId = 0, kSinkId = 1 };
source_node()630   Node* source_node() const { return FindNodeId(kSourceId); }
sink_node()631   Node* sink_node() const { return FindNodeId(kSinkId); }
632 
op_registry()633   const OpRegistryInterface* op_registry() const { return &ops_; }
flib_def()634   const FunctionLibraryDefinition& flib_def() const { return ops_; }
635 
CheckDeviceNameIndex(int index)636   void CheckDeviceNameIndex(int index) {
637     DCHECK_GE(index, 0);
638     DCHECK_LT(index, static_cast<int>(device_names_.size()));
639   }
640 
641   int InternDeviceName(const string& device_name);
642 
get_assigned_device_name(const Node & node)643   const string& get_assigned_device_name(const Node& node) const {
644     return device_names_[node.assigned_device_name_index()];
645   }
646 
set_assigned_device_name_index(Node * node,int device_name_index)647   void set_assigned_device_name_index(Node* node, int device_name_index) {
648     CheckDeviceNameIndex(device_name_index);
649     node->assigned_device_name_index_ = device_name_index;
650   }
651 
set_assigned_device_name(Node * node,const string & device_name)652   void set_assigned_device_name(Node* node, const string& device_name) {
653     node->assigned_device_name_index_ = InternDeviceName(device_name);
654   }
655 
656   // Returns OK if `node` is non-null and belongs to this graph
657   Status IsValidNode(const Node* node) const;
658 
659   // Returns OK if IsValidNode(`node`) and `idx` is a valid output.  Does not
660   // accept control outputs.
661   Status IsValidOutputTensor(const Node* node, int idx) const;
662 
663   // Returns OK if IsValidNode(`node`) and `idx` a valid input.  Does not accept
664   // control inputs.
665   Status IsValidInputTensor(const Node* node, int idx) const;
666 
667   // Create and return a new WhileContext owned by this graph. This is called
668   // when a new while loop is created. `frame_name` must be unique among
669   // WhileContexts in this graph.
670   Status AddWhileContext(StringPiece frame_name, std::vector<Node*> enter_nodes,
671                          std::vector<Node*> exit_nodes,
672                          OutputTensor cond_output,
673                          std::vector<OutputTensor> body_inputs,
674                          std::vector<OutputTensor> body_outputs,
675                          WhileContext** result);
676 
677   // Builds a node name to node pointer index for all nodes in the graph.
678   std::unordered_map<string, Node*> BuildNodeNameIndex() const;
679 
680   // TODO(josh11b): uint64 hash() const;
681 
682  private:
683   // If cost_node is non-null, then cost accounting (in CostModel)
684   // will be associated with that node rather than the new one being
685   // created.
686   //
687   // Ownership of the returned Node is not transferred to caller.
688   Node* AllocateNode(std::shared_ptr<NodeProperties> props,
689                      const Node* cost_node, bool is_function_op);
690   void ReleaseNode(Node* node);
691   // Insert edge in free_edges_ for possible reuse.
692   void RecycleEdge(const Edge* edge);
693   // Registry of all known ops, including functions.
694   FunctionLibraryDefinition ops_;
695 
696   // GraphDef versions
697   const std::unique_ptr<VersionDef> versions_;
698 
699   // Allocator which will give us good locality.
700   core::Arena arena_;
701 
702   // Map from node ids to allocated nodes.  nodes_[id] may be nullptr if
703   // the node with that id was removed from the graph.
704   std::vector<Node*> nodes_;
705 
706   // Number of nodes alive.
707   int64 num_nodes_ = 0;
708 
709   // Map from edge ids to allocated edges.  edges_[id] may be nullptr if
710   // the edge with that id was removed from the graph.
711   std::vector<Edge*> edges_;
712 
713   // The number of entries in edges_ that are not nullptr.
714   int num_edges_ = 0;
715 
716   // Allocated but free nodes and edges.
717   std::vector<Node*> free_nodes_;
718   std::vector<Edge*> free_edges_;
719 
720   // For generating unique names.
721   int name_counter_ = 0;
722 
723   // In most graphs, the number of unique values used for the
724   // Node::assigned_device_name() property is quite small.  If the graph is
725   // large, then this duplication of values can consume a significant amount of
726   // memory.  Instead, we represent the same information using an interning
727   // table, which consists of a vector of unique strings (device_names_), as
728   // well a map (device_names_map_) from unique strings to indices within the
729   // unique string table.
730   //
731   // The InternDeviceName() method handles adding a new entry into the table,
732   // or locating the index of an existing entry.
733   //
734   // The fact that Node::assigned_device_name() is implemented using an
735   // interning table is intentionally public.  This allows algorithms that
736   // frequently access this field to do so efficiently, especially for the case
737   // where the assigned_device_name of one Node is copied directly from that
738   // of another Node.
739 
740   // A table of the unique assigned device names.  Indices do NOT correspond
741   // to node IDs.  Index 0 is always the empty string.
742   std::vector<string> device_names_;
743 
744   // Maps unique device names to indices within device_names_[i].
745   std::unordered_map<string, int> device_names_map_;
746 
747   // All the while contexts owned by this graph, keyed by frame name,
748   // corresponding to all the while loops contained in this graph (including
749   // nested loops). The stored contexts are usually accessed via
750   // AddWhileContext() or Node::while_ctx(), but this manages the lifetime.
751   std::map<string, WhileContext> while_ctxs_;
752 
753   TF_DISALLOW_COPY_AND_ASSIGN(Graph);
754 };
755 
756 // TODO(josh11b): We may want to support keeping an index on various
757 // node/edge attributes in a graph, particularly node names.
758 
759 // Helper routines
760 
IsSource(const Node * node)761 inline bool IsSource(const Node* node) { return node->IsSource(); }
IsSink(const Node * node)762 inline bool IsSink(const Node* node) { return node->IsSink(); }
IsSwitch(const Node * node)763 inline bool IsSwitch(const Node* node) { return node->IsSwitch(); }
IsMerge(const Node * node)764 inline bool IsMerge(const Node* node) { return node->IsMerge(); }
IsEnter(const Node * node)765 inline bool IsEnter(const Node* node) { return node->IsEnter(); }
IsExit(const Node * node)766 inline bool IsExit(const Node* node) { return node->IsExit(); }
IsNextIteration(const Node * n)767 inline bool IsNextIteration(const Node* n) { return n->IsNextIteration(); }
IsLoopCond(const Node * node)768 inline bool IsLoopCond(const Node* node) { return node->IsLoopCond(); }
IsControlTrigger(const Node * n)769 inline bool IsControlTrigger(const Node* n) { return n->IsControlTrigger(); }
IsSend(const Node * node)770 inline bool IsSend(const Node* node) { return node->IsSend(); }
IsRecv(const Node * node)771 inline bool IsRecv(const Node* node) { return node->IsRecv(); }
IsHostSend(const Node * node)772 inline bool IsHostSend(const Node* node) { return node->IsHostSend(); }
IsHostRecv(const Node * node)773 inline bool IsHostRecv(const Node* node) { return node->IsHostRecv(); }
774 
775 // True for Nodes that mediate the transfer of values between processes.
IsTransferNode(const Node * n)776 inline bool IsTransferNode(const Node* n) { return IsSend(n) || IsRecv(n); }
777 
IsConstant(const Node * node)778 inline bool IsConstant(const Node* node) { return node->IsConstant(); }
IsVariable(const Node * node)779 inline bool IsVariable(const Node* node) { return node->IsVariable(); }
IsIdentity(const Node * node)780 inline bool IsIdentity(const Node* node) { return node->IsIdentity(); }
781 
782 // Returns true iff 'n' is a control flow node.
IsControlFlow(const Node * n)783 inline bool IsControlFlow(const Node* n) { return n->IsControlFlow(); }
784 
785 // Returns true if the node only depends on its input's metadata
786 // (shape).  Specifically, returns true for "Size", "Shape" and "Rank" ops.
IsMetadata(const Node * n)787 inline bool IsMetadata(const Node* n) { return n->IsMetadata(); }
788 
IsScopedAllocator(const Node * n)789 inline bool IsScopedAllocator(const Node* n) { return n->IsScopedAllocator(); }
790 
IsHostMemoryPreserving(const Node * node)791 inline bool IsHostMemoryPreserving(const Node* node) {
792   return IsIdentity(node) || IsControlFlow(node);
793 }
794 
// NOTE: NodeIter and NeighborIter declare their `reference` type as Node*
// (not Node&), because their operator* returns a Node* (see
// https://en.cppreference.com/w/cpp/iterator/iterator_traits).
797 
798 // Iterator for stepping through the nodes of a graph.
799 class NodeIter
800     : public std::iterator<std::forward_iterator_tag, Node, std::ptrdiff_t,
801                            /*Pointer*/ Node*, /*Reference*/ Node*> {
802  public:
803   NodeIter(const Graph* graph, int id);
804   bool operator==(const NodeIter& rhs) const;
805   bool operator!=(const NodeIter& rhs) const;
806   void operator++();
807   reference operator*() const;
808   pointer operator->() const;
809 
810  private:
811   // Invariant: id_ == graph_->num_node_ids() || graph_->FindId(id_) != nullptr
812   const Graph* graph_;
813   int id_;
814 };
815 
816 // Iterator for stepping through the neighbors of a node.
817 class NeighborIter
818     : public std::iterator<std::forward_iterator_tag, Node, std::ptrdiff_t,
819                            /*Pointer*/ Node*, /*Reference*/ Node*> {
820  public:
821   NeighborIter(EdgeSet::const_iterator iter, bool incoming);
822   bool operator==(const NeighborIter& rhs) const;
823   bool operator!=(const NeighborIter& rhs) const;
824   void operator++();
825   reference operator*() const;
826   pointer operator->() const;
827 
828  private:
829   EdgeSet::const_iterator iter_;
830   bool incoming_;
831 };
832 
833 // IMPLEMENTATION DETAILS, PLEASE IGNORE
834 
NodeIter(const Graph * graph,int id)835 inline NodeIter::NodeIter(const Graph* graph, int id)
836     : graph_(graph), id_(id) {}
837 
838 inline bool NodeIter::operator==(const NodeIter& rhs) const {
839   DCHECK(graph_ == rhs.graph_);
840   return id_ == rhs.id_;
841 }
842 
843 inline bool NodeIter::operator!=(const NodeIter& rhs) const {
844   return !(*this == rhs);
845 }
846 
847 inline void NodeIter::operator++() {
848   while (1) {
849     DCHECK_LE(id_, graph_->num_node_ids());
850     ++id_;
851     if (id_ >= graph_->num_node_ids() || graph_->FindNodeId(id_) != nullptr) {
852       return;
853     }
854   }
855 }
856 
857 inline Node* NodeIter::operator*() const { return graph_->FindNodeId(id_); }
858 
859 inline Node* NodeIter::operator->() const { return graph_->FindNodeId(id_); }
860 
NeighborIter(EdgeSet::const_iterator iter,bool incoming)861 inline NeighborIter::NeighborIter(EdgeSet::const_iterator iter, bool incoming)
862     : iter_(iter), incoming_(incoming) {}
863 
864 inline bool NeighborIter::operator==(const NeighborIter& rhs) const {
865   return iter_ == rhs.iter_ && incoming_ == rhs.incoming_;
866 }
867 
868 inline bool NeighborIter::operator!=(const NeighborIter& rhs) const {
869   return !(*this == rhs);
870 }
871 
872 inline void NeighborIter::operator++() { ++iter_; }
873 
874 inline Node* NeighborIter::operator*() const {
875   const Edge* e = *iter_;
876   return incoming_ ? e->src() : e->dst();
877 }
878 
879 inline Node* NeighborIter::operator->() const {
880   const Edge* e = *iter_;
881   return incoming_ ? e->src() : e->dst();
882 }
883 
IsControlEdge()884 inline bool Edge::IsControlEdge() const {
885   // Note that if either src_output_ or dst_input_ is kControlSlot,
886   // so is the other one (AddEdge checks this).
887   return src_output_ == Graph::kControlSlot;
888 }
889 
nodes()890 inline gtl::iterator_range<NodeIter> Graph::nodes() const {
891   // Note that NodeId 0 is always valid since we don't let the source
892   // node be removed from the graph.
893   return gtl::make_range(NodeIter(this, 0), NodeIter(this, num_node_ids()));
894 }
895 
op_nodes()896 inline gtl::iterator_range<NodeIter> Graph::op_nodes() const {
897   // Note that NodeId 0 is always valid since we don't let the source
898   // node be removed from the graph.
899   //
900   // The current implementation of Graph maintains the invariant that the
901   // first two nodes are the source and sink nodes, and all other nodes are op
902   // nodes. This method (op_nodes()) relies on this invariant.
903   NodeIter begin(this, 0);
904   NodeIter end(this, num_node_ids());
905   if (begin != end) {
906     ++begin;
907   }
908   if (begin != end) {
909     ++begin;
910   }
911   return gtl::make_range(begin, end);
912 }
913 
set_assigned_device_name_index(int index)914 inline void Node::set_assigned_device_name_index(int index) {
915   graph_->CheckDeviceNameIndex(index);
916   assigned_device_name_index_ = index;
917 }
918 
set_assigned_device_name(const string & device_name)919 inline void Node::set_assigned_device_name(const string& device_name) {
920   graph_->set_assigned_device_name(this, device_name);
921 }
922 
assigned_device_name()923 inline const string& Node::assigned_device_name() const {
924   return graph_->get_assigned_device_name(*this);
925 }
926 
927 }  // namespace tensorflow
928 
929 #endif  // TENSORFLOW_CORE_GRAPH_GRAPH_H_
930