/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// A Graph describes a set of computations that are to be
// performed, as well as the dependencies between those
// computations. The basic model is a DAG (directed acyclic graph) with
// * internal nodes representing computational operations to be performed;
// * edges representing dependencies, indicating that the target may only be
//   executed once the source has completed; and
// * predefined "source" (start) and "sink" (finish) nodes -- the source
//   should be the only node that doesn't depend on anything, and the sink
//   should be the only node that nothing depends on.
//
// Note: Node ids are intended to be relatively dense in the
// 0..max_id range, but there may be gaps since ids won't be reused.
//
// Note: Some dependencies between operations are due to one operation
// consuming the output of another. In fact, operations can produce
// multiple outputs and consume multiple inputs, and some
// optimizations will care about which specific outputs are connected
// to which specific inputs. We therefore represent the data dependency
// between output O of node A and input I of node B using
// "output index" and "input index" labels on each edge.

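// For example (illustrative, not part of the API): if node A's output 1 feeds
// node B's input 0, that data dependency is stored as an Edge with
// src_output() == 1 and dst_input() == 0; a control dependency between the
// same two nodes would instead carry the special slot value Graph::kControlSlot
// on both ends.
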
#ifndef TENSORFLOW_CORE_GRAPH_GRAPH_H_
#define TENSORFLOW_CORE_GRAPH_GRAPH_H_

#include <functional>
#include <string>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/edgeset.h"
#include "tensorflow/core/lib/core/arena.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

class Edge;
class EdgeSetTest;
class Graph;
class GraphDef;
class Node;
struct OutputTensor;
class VersionDef;
class WhileContext;

class NeighborIter;     // Declared below
class NodeIter;         // Declared below

// Indicates where the graph instance originated from.
enum class ConstructionContext {
  kNotTracked,     // Not tracked.
  kDirectSession,  // From `tensorflow::DirectSession`, TF1 session API.
  kEagerRuntime,   // Registered from the TF2 eager runtime.
};

class Node {
 public:
  std::string DebugString() const;
  int id() const { return id_; }
  int cost_id() const { return cost_id_; }
  const std::string& name() const;
  void set_name(std::string name);
  const std::string& type_string() const;

  // def() provides the NodeDef the user supplied, but the specifics
  // of this Node may have changed due to placement, optimization, etc.
  // In particular:
  // * def().name() will match name();
  // * def().op() will match type_string() and op_def().name();
  // * def().input() is not reliable; use "in_edges()" below instead;
  // * def().device() is the "user's requested device" and may not match
  //   the actual assigned device; see assigned_device_name() below;
  // * def().attr() is authoritative.
  // TODO(irving): Replace with NodeInfo.
  const NodeDef& def() const;
  const OpDef& op_def() const;

  // Input and output types.
  int32 num_inputs() const;
  DataType input_type(int32 i) const;
  const DataTypeVector& input_types() const;

  int32 num_outputs() const;
  DataType output_type(int32 o) const;
  const DataTypeVector& output_types() const;

  // The device requested by the user.  For the actual assigned device,
  // use assigned_device_name() below.
  const std::string& requested_device() const;

  // This changes the user-requested device but not necessarily the device on
  // which the operation will run.
  void set_requested_device(const std::string& device);

  // This gives the device the runtime has assigned this node to.  If
  // you want the device the user requested, use def().device() instead.
  // TODO(josh11b): Validate that the assigned device, if not empty,
  // fully specifies a device and satisfies def().device().
  // TODO(josh11b): Move assigned_device_name outside of Node into a
  // NodeId->DeviceName map.
  const std::string& assigned_device_name() const;
  void set_assigned_device_name(const std::string& device_name);
  bool has_assigned_device_name() const {
    return assigned_device_name_index_ > 0;
  }
  int assigned_device_name_index() const { return assigned_device_name_index_; }
  void set_assigned_device_name_index(int index);

  // Sets the 'original_node_names' field of this node's DebugInfo proto to
  // 'names'.
  void set_original_node_names(const std::vector<string>& names);

  // Read-only access to attributes.
  AttrSlice attrs() const;

  // Inputs requested by the NodeDef.  For the actual inputs, use in_edges().
  const protobuf::RepeatedPtrField<string>& requested_inputs() const;

  // Get the neighboring nodes via edges either in or out of this node.  This
  // includes control edges.
  gtl::iterator_range<NeighborIter> in_nodes() const;
  gtl::iterator_range<NeighborIter> out_nodes() const;
  const EdgeSet& in_edges() const { return in_edges_; }
  const EdgeSet& out_edges() const { return out_edges_; }

  // Node type helpers.
  bool IsSource() const { return id() == 0; }
  bool IsSink() const { return id() == 1; }
  // Anything other than the special Source & Sink nodes.
  bool IsOp() const { return id() > 1; }

  // Node class helpers.
  bool IsSwitch() const { return class_ == NC_SWITCH; }
  bool IsMerge() const { return class_ == NC_MERGE; }
  bool IsEnter() const { return class_ == NC_ENTER; }
  bool IsExit() const { return class_ == NC_EXIT; }
  bool IsNextIteration() const { return class_ == NC_NEXT_ITERATION; }
  bool IsLoopCond() const { return class_ == NC_LOOP_COND; }
  bool IsControlTrigger() const { return class_ == NC_CONTROL_TRIGGER; }
  bool IsSend() const { return class_ == NC_SEND || class_ == NC_HOST_SEND; }
  bool IsRecv() const { return class_ == NC_RECV || class_ == NC_HOST_RECV; }
  bool IsConstant() const { return class_ == NC_CONSTANT; }
  bool IsVariable() const { return class_ == NC_VARIABLE; }
  bool IsIdentity() const { return class_ == NC_IDENTITY; }
  bool IsGetSessionHandle() const { return class_ == NC_GET_SESSION_HANDLE; }
  bool IsGetSessionTensor() const { return class_ == NC_GET_SESSION_TENSOR; }
  bool IsDeleteSessionTensor() const {
    return class_ == NC_DELETE_SESSION_TENSOR;
  }
  bool IsControlFlow() const {
    return (class_ != NC_OTHER) &&  // Fast path
           (IsSwitch() || IsMerge() || IsEnter() || IsExit() ||
            IsNextIteration());
  }
  bool IsHostSend() const { return class_ == NC_HOST_SEND; }
  bool IsHostRecv() const { return class_ == NC_HOST_RECV; }
  bool IsScopedAllocator() const { return class_ == NC_SCOPED_ALLOCATOR; }
  bool IsCollective() const { return class_ == NC_COLLECTIVE; }

  bool IsMetadata() const { return class_ == NC_METADATA; }
  bool IsFakeParam() const { return class_ == NC_FAKE_PARAM; }
  bool IsPartitionedCall() const { return class_ == NC_PARTITIONED_CALL; }

  // Returns true if this node is any kind of function call node.
  //
  // NOTE: "function call nodes" include partitioned call ops, symbolic gradient
  // ops, and ops whose type_string is the name of a function ("function ops").
  bool IsFunctionCall() const {
    return class_ == NC_PARTITIONED_CALL || class_ == NC_FUNCTION_OP ||
           class_ == NC_SYMBOLIC_GRADIENT;
  }

  bool IsIfNode() const { return class_ == NC_IF; }
  bool IsWhileNode() const { return class_ == NC_WHILE; }
  bool IsCaseNode() const { return class_ == NC_CASE; }
  // Is this node a function input?
  bool IsArg() const { return class_ == NC_ARG; }
  // Is this node a function output?
  bool IsRetval() const { return class_ == NC_RETVAL; }

  template <typename T>
  void AddAttr(const std::string& name, const T& val) {
    SetAttrValue(val, AddAttrHelper(name));
    UpdateProperties();
  }

  void AddAttr(const std::string& name, std::vector<string>&& val) {
    MoveAttrValue(std::move(val), AddAttrHelper(name));
    UpdateProperties();
  }

  void ClearAttr(const std::string& name);

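  // Example (illustrative; "my_attr" is an arbitrary attr name, not part of
  // this API): set or overwrite an attr on a node, then remove it again:
  //   node->AddAttr("my_attr", 42);
  //   node->ClearAttr("my_attr");
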
  // Returns into '*e' the edge connecting to the 'idx' input of this Node.
  Status input_edge(int idx, const Edge** e) const;

  // Returns into '*edges' the input data edges of this Node, indexed by input
  // number. Does not return control edges.
  Status input_edges(std::vector<const Edge*>* edges) const;

  // Returns into '*n' the node that has an output connected to the
  // 'idx' input of this Node.
  Status input_node(int idx, const Node** n) const;
  Status input_node(int idx, Node** n) const;

  // Returns into '*t' the idx-th input tensor of this node, represented as the
  // output tensor of input_node(idx).
  Status input_tensor(int idx, OutputTensor* t) const;

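  // Example (illustrative) of walking all data inputs of a node;
  // TF_RETURN_IF_ERROR is assumed to be available from
  // tensorflow/core/lib/core/errors.h:
  //   for (int i = 0; i < node->num_inputs(); ++i) {
  //     const Edge* e = nullptr;
  //     TF_RETURN_IF_ERROR(node->input_edge(i, &e));
  //     Node* producer = e->src();
  //     int producer_output = e->src_output();
  //     ...
  //   }
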
  WhileContext* while_ctx() const { return while_ctx_; }
  void set_while_ctx(WhileContext* while_ctx) {
    DCHECK(IsExit());
    DCHECK(while_ctx_ == nullptr);
    while_ctx_ = while_ctx;
  }

  std::shared_ptr<NodeProperties> properties() const { return props_; }

  // Sets the stack trace for the node. Assumes that getting and setting the
  // stack trace for a given node will not race.
  void SetStackTrace(const std::shared_ptr<AbstractStackTrace>& stack_trace) {
    stack_trace_ = stack_trace;
  }

  // Gets the stack trace recorded when the node was instantiated.
  const std::shared_ptr<AbstractStackTrace>& GetStackTrace() const {
    return stack_trace_;
  }

 private:
  friend class Graph;
  Node();

  // Stack trace for the user code at node instantiation. Can be shared across
  // multiple nodes (e.g. when inlining).
  std::shared_ptr<AbstractStackTrace> stack_trace_;

  // Releases memory from props_, in addition to restoring *this to its
  // uninitialized state.
  void Clear();

  // Makes a copy of the Node's props_ if props_ is shared with
  // other nodes. This must be called before mutating properties,
  // e.g. in AddAttr.
  void MaybeCopyOnWrite();

  // Called after an attr has changed. Decides whether we need to update some
  // property of the node (stored in props_).
  void UpdateProperties();

  AttrValue* AddAttrHelper(const std::string& name);

  // A set of mutually exclusive classes for different kinds of nodes.
  // class_ is initialized in the Node::Initialize routine based on the
  // node's type_string().
  enum NodeClass {
    NC_UNINITIALIZED,
    NC_SWITCH,
    NC_MERGE,
    NC_ENTER,
    NC_EXIT,
    NC_NEXT_ITERATION,
    NC_LOOP_COND,
    NC_CONTROL_TRIGGER,
    NC_SEND,
    NC_HOST_SEND,
    NC_RECV,
    NC_HOST_RECV,
    NC_CONSTANT,
    NC_VARIABLE,
    NC_IDENTITY,
    NC_GET_SESSION_HANDLE,
    NC_GET_SESSION_TENSOR,
    NC_DELETE_SESSION_TENSOR,
    NC_METADATA,
    NC_SCOPED_ALLOCATOR,
    NC_COLLECTIVE,
    NC_FAKE_PARAM,
    NC_PARTITIONED_CALL,
    NC_FUNCTION_OP,
    NC_SYMBOLIC_GRADIENT,
    NC_IF,
    NC_WHILE,
    NC_CASE,
    NC_ARG,
    NC_RETVAL,
    NC_OTHER  // Not a special kind of node
  };

  void Initialize(int id, int cost_id, std::shared_ptr<NodeProperties> props,
                  NodeClass node_class);

  static NodeClass GetNodeClassForOp(const std::string& ts);

  int id_;       // -1 until Initialize() is called
  int cost_id_;  // -1 if there is no corresponding cost accounting node
  NodeClass class_;

  EdgeSet in_edges_;
  EdgeSet out_edges_;

  // NOTE(skyewm): inheriting from core::RefCounted may have a slight
  // performance benefit over using shared_ptr, at the cost of manual ref
  // counting.
  std::shared_ptr<NodeProperties> props_;

  // Index within Graph::device_names_ of the name of the device assigned
  // to perform this computation.
  int assigned_device_name_index_;

  // A back-pointer to the Graph that owns this node.  Currently, this exists
  // solely to allow Node::[set_]assigned_device_name() to work. However, if all
  // callers of Node::[set_]assigned_device_name() are modified to use the
  // equivalent methods defined directly on Graph, then we can remove this
  // field and reclaim that memory.
  Graph* graph_;

  // Set if this is an exit node of a while loop with an associated
  // WhileContext. Otherwise null. (This is only set for exit nodes because
  // they're the first nodes of a loop encountered while creating the gradient
  // graph. Exit nodes that are part of while loop gradient graphs will not
  // have this set.)
  WhileContext* while_ctx_;

  TF_DISALLOW_COPY_AND_ASSIGN(Node);
};

// Stores debug information associated with the Node.
struct NodeDebugInfo {
  const std::string name;
  std::vector<string> original_node_names;

  NodeDebugInfo(const Node& n);
  NodeDebugInfo(const NodeDef& ndef);
  NodeDebugInfo(StringPiece node_name, bool has_experimental_debug_info,
                const NodeDef_ExperimentalDebugInfo& experimental_debug_info);
};

// Represents an input of a node, i.e., the `index`-th input to `node`.
struct InputTensor {
  Node* node;
  int index;

  InputTensor(Node* n, int i) : node(n), index(i) {}
  InputTensor() : node(nullptr), index(0) {}

  // Returns true if this InputTensor is identical to 'other'. Nodes are
  // compared using pointer equality.
  bool operator==(const InputTensor& other) const;

  // A hash function for InputTensors. Nodes are hashed based on their pointer
  // value.
  struct Hash {
    uint64 operator()(InputTensor const& s) const;
  };
};

// Represents an output of a node, i.e., the `index`-th output of `node`. Note
// that a single `OutputTensor` can correspond to multiple `Edge`s if the
// output is consumed by multiple destination nodes.
struct OutputTensor {
  Node* node;
  int index;

  OutputTensor(Node* n, int i) : node(n), index(i) {}
  OutputTensor() : node(nullptr), index(0) {}

  // Returns true if this OutputTensor is identical to 'other'. Nodes are
  // compared using pointer equality.
  bool operator==(const OutputTensor& other) const;

  // A hash function for OutputTensors. Nodes are hashed based on their pointer
  // value.
  struct Hash {
    uint64 operator()(OutputTensor const& s) const;
  };
};

class Edge {
 public:
  Node* src() const { return src_; }
  Node* dst() const { return dst_; }
  int id() const { return id_; }

  // Returns the index of the source output that produces the data
  // carried by this edge.  The special value kControlSlot is used
  // for control dependencies.
  int src_output() const { return src_output_; }

  // Returns the index of the destination input that consumes the data
  // carried by this edge.  The special value kControlSlot is used
  // for control dependencies.
  int dst_input() const { return dst_input_; }

  // Returns true iff this is an edge that indicates a control-flow
  // (as opposed to a data-flow) dependency.
  bool IsControlEdge() const;

  std::string DebugString() const;

 private:
  Edge() {}

  friend class EdgeSetTest;
  friend class Graph;
  Node* src_;
  Node* dst_;
  int id_;
  int src_output_;
  int dst_input_;
};

// Allows for iteration of the edges of a Graph, by iterating the underlying
// Graph.edges_ vector while skipping over null entries.
class GraphEdgesIterable {
 private:
  const std::vector<Edge*>& edges_;

 public:
  explicit GraphEdgesIterable(const std::vector<Edge*>& edges)
      : edges_(edges) {}

  typedef Edge* value_type;

  class const_iterator {
   private:
    // The underlying iterator.
    std::vector<value_type>::const_iterator iter_;

    // The end of the underlying iterator.
    std::vector<value_type>::const_iterator end_;

    // Advances iter_ until it reaches a non-null item, or reaches the end.
    void apply_filter() {
      while (iter_ != end_ && *iter_ == nullptr) {
        ++iter_;
      }
    }

   public:
    const_iterator(std::vector<value_type>::const_iterator iter,
                   std::vector<value_type>::const_iterator end)
        : iter_(iter), end_(end) {
      apply_filter();
    }

    bool operator==(const const_iterator& other) const {
      return iter_ == other.iter_;
    }

    bool operator!=(const const_iterator& other) const {
      return iter_ != other.iter_;
    }

    // This is the prefix increment operator (++x), which is the operator
    // used by C++ range iteration (for (x : y) ...).  We intentionally do not
    // provide a postfix increment operator.
    const_iterator& operator++() {
      ++iter_;
      apply_filter();
      return *this;
    }

    value_type operator*() { return *iter_; }
  };

  const_iterator begin() {
    return const_iterator(edges_.begin(), edges_.end());
  }
  const_iterator end() { return const_iterator(edges_.end(), edges_.end()); }
};

// Thread compatible but not thread safe.
class Graph {
 public:
  // Constructs a graph with a single SOURCE (always id kSourceId) and a
  // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK.
  //
  // The graph can hold ops found in the registry. The lifetime of `ops` must
  // be at least that of the constructed graph.
  explicit Graph(const OpRegistryInterface* ops);

  // Constructs a graph with a single SOURCE (always id kSourceId) and a
  // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK.
  //
  // The graph can hold ops found in `flib_def`. Unlike the constructor taking
  // an OpRegistryInterface, this constructor copies the function definitions in
  // `flib_def`, so its lifetime may be shorter than the graph's. The
  // OpRegistryInterface backing `flib_def` must still live at least as long as
  // the graph, though.
  explicit Graph(const FunctionLibraryDefinition& flib_def);

  ~Graph();

  static const int kControlSlot;

  // The GraphDef version range of this graph (see graph.proto).
  const VersionDef& versions() const;
  void set_versions(const VersionDef& versions);

  // Adds a new node to this graph, and returns it. Infers the Op and
  // input/output types for the node. *this owns the returned instance.
  // Returns nullptr and sets *status on error.
  Node* AddNode(NodeDef node_def, Status* status);

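  // Example (illustrative sketch; the op and node names are placeholders, and
  // a real NodeDef typically also needs its inputs and data-type attrs filled
  // in, e.g. via NodeDefBuilder):
  //   Graph graph(OpRegistry::Global());
  //   NodeDef node_def;
  //   node_def.set_name("example_identity");
  //   node_def.set_op("Identity");
  //   Status status;
  //   Node* node = graph.AddNode(std::move(node_def), &status);
  //   if (!status.ok()) { /* handle error */ }
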
  // Copies *node, which may belong to another graph, to a new node,
  // which is returned.  Does not copy any edges.  *this owns the
  // returned instance.
  Node* CopyNode(const Node* node);

  // Removes a node from this graph, including all edges from or to it.
  // *node should not be accessed after calling this function.
  // REQUIRES: node->IsOp()
  void RemoveNode(Node* node);

  void Copy(const Graph& src);

  // Adds an edge that connects the xth output of `source` to the yth input of
  // `dest` and returns it. Does not update dest's NodeDef.
  const Edge* AddEdge(Node* source, int x, Node* dest, int y);

  // Adds a control edge (no data flows along this edge) that connects `source`
  // to `dest`. If `dest`'s NodeDef is missing the corresponding control input,
  // adds the control input.
  //
  // If such a control edge already exists and `allow_duplicates` is false, no
  // edge is added and the function returns nullptr. Otherwise the edge is
  // unconditionally created and returned. The NodeDef is not updated if
  // `allow_duplicates` is true.
  // TODO(skyewm): allow_duplicates is needed only by graph_partition.cc.
  // Figure out if we can do away with it.
  const Edge* AddControlEdge(Node* source, Node* dest,
                             bool allow_duplicates = false);

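  // Example (illustrative; `a` and `b` are assumed to be nodes already added
  // to this graph): connect output 0 of `a` to input 1 of `b`, then add a
  // control dependency from `a` to `b`:
  //   const Edge* data_edge = graph.AddEdge(a, 0, b, 1);
  //   const Edge* ctrl_edge = graph.AddControlEdge(a, b);
  //   // Both slots of a control edge hold Graph::kControlSlot.
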
  // Removes edge from the graph. Does not update the destination node's
  // NodeDef.
  // REQUIRES: The edge must exist.
  void RemoveEdge(const Edge* edge);

  // Removes control edge `edge` from the graph. Note that this also updates
  // the corresponding NodeDef to reflect the change.
  // REQUIRES: The control edge must exist.
  void RemoveControlEdge(const Edge* e);

  // Updates the input to a node.  The existing edge to `dst` is removed and an
  // edge from `new_src` to `dst` is created. The NodeDef associated with `dst`
  // is also updated.
  Status UpdateEdge(Node* new_src, int new_src_index, Node* dst, int dst_index);

  // Like AddEdge but updates dst's NodeDef. Used to add an input edge to a
  // "While" op during gradient construction, see AddInputWhileHack in
  // python_api.h for more details.
  Status AddWhileInputHack(Node* new_src, int new_src_index, Node* dst);

  // Adds the function and gradient definitions in `fdef_lib` to this graph's op
  // registry. Ignores duplicate functions, and returns a bad status if an
  // imported function differs from an existing function or op with the same
  // name.
  Status AddFunctionLibrary(const FunctionDefLibrary& fdef_lib);

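  // Example (illustrative; `fdef_lib` is assumed to be a FunctionDefLibrary
  // populated elsewhere, e.g. copied from a FunctionLibraryDefinition):
  //   Status s = graph.AddFunctionLibrary(fdef_lib);
  //   if (!s.ok()) { /* an imported function clashed with an existing one */ }
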
  // The number of live nodes in the graph.
  //
  // Because nodes can be removed from the graph, num_nodes() is often
  // smaller than num_node_ids(). If one needs to create an array of
  // nodes indexed by node ids, num_node_ids() should be used as the
  // array's size.
  int num_nodes() const { return num_nodes_; }

  // The number of live nodes in the graph, excluding the Source and Sink nodes.
  int num_op_nodes() const {
    DCHECK_GE(num_nodes_, 2);
    return num_nodes_ - 2;
  }

  // The number of live edges in the graph.
  //
  // Because edges can be removed from the graph, num_edges() is often
  // smaller than num_edge_ids(). If one needs to create an array of
  // edges indexed by edge ids, num_edge_ids() should be used as the
  // array's size.
  int num_edges() const { return num_edges_; }

  // Serialize the nodes starting at `from_node_id` to a GraphDef.
  void ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const;

  // Serialize to a GraphDef.
  void ToGraphDef(GraphDef* graph_def) const;

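  // Example (illustrative): snapshot the current graph into a proto:
  //   GraphDef graph_def;
  //   graph.ToGraphDef(&graph_def);
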
  // This version can be called from a debugger to inspect the graph content.
  // Use the previous version outside a debug context for efficiency reasons.
  //
  // Note: We do not expose a DebugString() API, since GraphDef.DebugString() is
  // not defined in some TensorFlow builds.
  GraphDef ToGraphDefDebug() const;

  // Generates a new node name with the specified prefix that is unique
  // across this graph.
  std::string NewName(StringPiece prefix);

  // Access to the list of all nodes.  Example usage:
  //   for (Node* node : graph.nodes()) { ... }
  gtl::iterator_range<NodeIter> nodes() const;

  // Access to the list of all nodes, excluding the Source and Sink nodes.
  gtl::iterator_range<NodeIter> op_nodes() const;

  // Returns one more than the maximum id assigned to any node.
  int num_node_ids() const { return nodes_.size(); }

  // Returns the node associated with an id, or nullptr if no node
  // with that id exists (the node with that id was removed and the id has
  // not yet been re-used). *this owns the returned instance.
  // REQUIRES: 0 <= id < num_node_ids().
  Node* FindNodeId(int id) const { return nodes_[id]; }

  // Returns one more than the maximum id assigned to any edge.
  int num_edge_ids() const { return edges_.size(); }

  // Returns the Edge associated with an id, or nullptr if no edge
  // with that id exists (the edge with that id was removed and the id has
  // not yet been re-used). *this owns the returned instance.
  // REQUIRES: 0 <= id < num_edge_ids().
  const Edge* FindEdgeId(int id) const { return edges_[id]; }

  // Access to the set of all edges.  Example usage:
  //   for (const Edge* e : graph.edges()) { ... }
  GraphEdgesIterable edges() const { return GraphEdgesIterable(edges_); }

  // The pre-defined nodes.
  enum { kSourceId = 0, kSinkId = 1 };
  Node* source_node() const { return FindNodeId(kSourceId); }
  Node* sink_node() const { return FindNodeId(kSinkId); }

  const OpRegistryInterface* op_registry() const { return &ops_; }
  const FunctionLibraryDefinition& flib_def() const { return ops_; }

  void CheckDeviceNameIndex(int index) {
    DCHECK_GE(index, 0);
    DCHECK_LT(index, static_cast<int>(device_names_.size()));
  }

  int InternDeviceName(const std::string& device_name);

  const std::string& get_assigned_device_name(const Node& node) const {
    return device_names_[node.assigned_device_name_index()];
  }

  void set_assigned_device_name_index(Node* node, int device_name_index) {
    CheckDeviceNameIndex(device_name_index);
    node->assigned_device_name_index_ = device_name_index;
  }

  void set_assigned_device_name(Node* node, const std::string& device_name) {
    node->assigned_device_name_index_ = InternDeviceName(device_name);
  }

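  // Example (illustrative): device names are interned, so copying an
  // assignment from one node to another can reuse the index instead of the
  // string (`src` and `dst` are assumed to be nodes in this graph):
  //   graph.set_assigned_device_name(src, "/job:worker/task:0/device:CPU:0");
  //   graph.set_assigned_device_name_index(dst, src->assigned_device_name_index());
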
  // Returns OK if `node` is non-null and belongs to this graph.
  Status IsValidNode(const Node* node) const;

  // Returns OK if IsValidNode(`node`) and `idx` is a valid output.  Does not
  // accept control outputs.
  Status IsValidOutputTensor(const Node* node, int idx) const;

  // Returns OK if IsValidNode(`node`) and `idx` is a valid input.  Does not
  // accept control inputs.
  Status IsValidInputTensor(const Node* node, int idx) const;

  // Creates and returns a new WhileContext owned by this graph. This is called
  // when a new while loop is created. `frame_name` must be unique among
  // WhileContexts in this graph.
  Status AddWhileContext(StringPiece frame_name, std::vector<Node*> enter_nodes,
                         std::vector<Node*> exit_nodes,
                         OutputTensor cond_output,
                         std::vector<OutputTensor> body_inputs,
                         std::vector<OutputTensor> body_outputs,
                         WhileContext** result);

  // Builds a node-name-to-node-pointer index for all nodes in the graph.
  std::unordered_map<string, Node*> BuildNodeNameIndex() const;

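  // Example (illustrative; "my_node" is a hypothetical node name):
  //   std::unordered_map<string, Node*> by_name = graph.BuildNodeNameIndex();
  //   auto it = by_name.find("my_node");
  //   Node* found = (it == by_name.end()) ? nullptr : it->second;
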
  absl::optional<std::vector<bool>>& GetConstArgIndicesCache() const {
    return const_arg_indices_cache_;
  }

  // TODO(kkb): Add to the constructor when it becomes manageable.
  // Sets the graph construction context.
  void SetConstructionContext(ConstructionContext construction_context) {
    construction_context_ = construction_context;
  }

  // TODO(kkb): Rename to `GetConstructionContext` once we're comfortable
  // making this stable and make it available widely.
  // Returns the graph construction context. It's `kNotTracked` if not set.
  ConstructionContext GetConstructionContextInternal() const {
    return construction_context_;
  }

  // TODO(josh11b): uint64 hash() const;

 private:
  // If cost_node is non-null, then cost accounting (in CostModel)
  // will be associated with that node rather than the new one being
  // created.
  //
  // Ownership of the returned Node is not transferred to caller.
  Node* AllocateNode(std::shared_ptr<NodeProperties> props,
                     const Node* cost_node, Node::NodeClass node_class);
  void ReleaseNode(Node* node);
  // Inserts the edge in free_edges_ for possible reuse.
  void RecycleEdge(const Edge* edge);
  // Registry of all known ops, including functions.
  FunctionLibraryDefinition ops_;

  // GraphDef versions
  const std::unique_ptr<VersionDef> versions_;

  // Allocator which will give us good locality.
  core::Arena arena_;

  // Map from node ids to allocated nodes.  nodes_[id] may be nullptr if
  // the node with that id was removed from the graph.
  std::vector<Node*> nodes_;

  // Number of nodes alive.
  int64 num_nodes_ = 0;

  // Map from edge ids to allocated edges.  edges_[id] may be nullptr if
  // the edge with that id was removed from the graph.
  std::vector<Edge*> edges_;

  // The number of entries in edges_ that are not nullptr.
  int num_edges_ = 0;

  // Allocated but free nodes and edges.
  std::vector<Node*> free_nodes_;
  std::vector<Edge*> free_edges_;

  // For generating unique names.
  int name_counter_ = 0;

  // In most graphs, the number of unique values used for the
  // Node::assigned_device_name() property is quite small.  If the graph is
  // large, then this duplication of values can consume a significant amount of
  // memory.  Instead, we represent the same information using an interning
  // table, which consists of a vector of unique strings (device_names_), as
  // well as a map (device_names_map_) from unique strings to indices within
  // the unique string table.
  //
  // The InternDeviceName() method handles adding a new entry into the table,
  // or locating the index of an existing entry.
  //
  // The fact that Node::assigned_device_name() is implemented using an
  // interning table is intentionally public.  This allows algorithms that
  // frequently access this field to do so efficiently, especially for the case
  // where the assigned_device_name of one Node is copied directly from that
  // of another Node.

  // A table of the unique assigned device names.  Indices do NOT correspond
  // to node IDs.  Index 0 is always the empty string.
  std::vector<string> device_names_;

  // Maps unique device names to indices within device_names_.
  std::unordered_map<string, int> device_names_map_;

  // All the while contexts owned by this graph, keyed by frame name,
  // corresponding to all the while loops contained in this graph (including
  // nested loops). The stored contexts are usually accessed via
  // AddWhileContext() or Node::while_ctx(), but this manages the lifetime.
  std::map<string, WhileContext> while_ctxs_;

  // Cache of the indices of the arguments which need to be constant for XLA
  // compilation.
  mutable absl::optional<std::vector<bool>> const_arg_indices_cache_;

  // Indicates the context in which this Graph instance was constructed.
  ConstructionContext construction_context_ = ConstructionContext::kNotTracked;

  TF_DISALLOW_COPY_AND_ASSIGN(Graph);
};

// TODO(josh11b): We may want to support keeping an index on various
// node/edge attributes in a graph, particularly node names.

// Helper routines

inline bool IsSource(const Node* node) { return node->IsSource(); }
inline bool IsSink(const Node* node) { return node->IsSink(); }
inline bool IsSwitch(const Node* node) { return node->IsSwitch(); }
inline bool IsMerge(const Node* node) { return node->IsMerge(); }
inline bool IsEnter(const Node* node) { return node->IsEnter(); }
inline bool IsExit(const Node* node) { return node->IsExit(); }
inline bool IsNextIteration(const Node* n) { return n->IsNextIteration(); }
inline bool IsLoopCond(const Node* node) { return node->IsLoopCond(); }
inline bool IsControlTrigger(const Node* n) { return n->IsControlTrigger(); }
inline bool IsSend(const Node* node) { return node->IsSend(); }
inline bool IsRecv(const Node* node) { return node->IsRecv(); }
inline bool IsHostSend(const Node* node) { return node->IsHostSend(); }
inline bool IsHostRecv(const Node* node) { return node->IsHostRecv(); }

// True for Nodes that mediate the transfer of values between processes.
inline bool IsTransferNode(const Node* n) { return IsSend(n) || IsRecv(n); }

inline bool IsConstant(const Node* node) { return node->IsConstant(); }
inline bool IsVariable(const Node* node) { return node->IsVariable(); }
inline bool IsIdentity(const Node* node) { return node->IsIdentity(); }

// Returns true iff 'n' is a control flow node.
inline bool IsControlFlow(const Node* n) { return n->IsControlFlow(); }

// Returns true if the node only depends on its input's metadata
// (shape).  Specifically, returns true for "Size", "Shape" and "Rank" ops.
inline bool IsMetadata(const Node* n) { return n->IsMetadata(); }

inline bool IsScopedAllocator(const Node* n) { return n->IsScopedAllocator(); }

inline bool IsHostMemoryPreserving(const Node* node) {
  return IsIdentity(node) || IsControlFlow(node);
}

// NOTE: We declare Reference type of NodeIter and NeighborIter as Node* (see
// https://en.cppreference.com/w/cpp/iterator/iterator).

// Iterator for stepping through the nodes of a graph.
class NodeIter
    : public std::iterator<std::forward_iterator_tag, Node, std::ptrdiff_t,
                           /*Pointer*/ Node*, /*Reference*/ Node*> {
 public:
  NodeIter(const Graph* graph, int id);
  bool operator==(const NodeIter& rhs) const;
  bool operator!=(const NodeIter& rhs) const;
  void operator++();
  reference operator*() const;
  pointer operator->() const;

 private:
  // Invariant: id_ == graph_->num_node_ids() || graph_->FindNodeId(id_) != nullptr
  const Graph* graph_;
  int id_;
};

// Iterator for stepping through the neighbors of a node.
class NeighborIter
    : public std::iterator<std::forward_iterator_tag, Node, std::ptrdiff_t,
                           /*Pointer*/ Node*, /*Reference*/ Node*> {
 public:
  NeighborIter(EdgeSet::const_iterator iter, bool incoming);
  bool operator==(const NeighborIter& rhs) const;
  bool operator!=(const NeighborIter& rhs) const;
  void operator++();
  reference operator*() const;
  pointer operator->() const;

 private:
  EdgeSet::const_iterator iter_;
  bool incoming_;
};

// IMPLEMENTATION DETAILS, PLEASE IGNORE

inline NodeIter::NodeIter(const Graph* graph, int id)
    : graph_(graph), id_(id) {}

inline bool NodeIter::operator==(const NodeIter& rhs) const {
  DCHECK(graph_ == rhs.graph_);
  return id_ == rhs.id_;
}

inline bool NodeIter::operator!=(const NodeIter& rhs) const {
  return !(*this == rhs);
}

inline void NodeIter::operator++() {
  while (1) {
    DCHECK_LE(id_, graph_->num_node_ids());
    ++id_;
    if (id_ >= graph_->num_node_ids() || graph_->FindNodeId(id_) != nullptr) {
      return;
    }
  }
}

inline Node* NodeIter::operator*() const { return graph_->FindNodeId(id_); }

inline Node* NodeIter::operator->() const { return graph_->FindNodeId(id_); }

inline NeighborIter::NeighborIter(EdgeSet::const_iterator iter, bool incoming)
    : iter_(iter), incoming_(incoming) {}

inline bool NeighborIter::operator==(const NeighborIter& rhs) const {
  return iter_ == rhs.iter_ && incoming_ == rhs.incoming_;
}

inline bool NeighborIter::operator!=(const NeighborIter& rhs) const {
  return !(*this == rhs);
}

inline void NeighborIter::operator++() { ++iter_; }

inline Node* NeighborIter::operator*() const {
  const Edge* e = *iter_;
  return incoming_ ? e->src() : e->dst();
}

inline Node* NeighborIter::operator->() const {
  const Edge* e = *iter_;
  return incoming_ ? e->src() : e->dst();
}

inline bool Edge::IsControlEdge() const {
  // Note that if either src_output_ or dst_input_ is kControlSlot,
  // so is the other one (AddEdge checks this).
  return src_output_ == Graph::kControlSlot;
}

inline gtl::iterator_range<NodeIter> Graph::nodes() const {
  // Note that NodeId 0 is always valid since we don't let the source
  // node be removed from the graph.
  return gtl::make_range(NodeIter(this, 0), NodeIter(this, num_node_ids()));
}

inline gtl::iterator_range<NodeIter> Graph::op_nodes() const {
  // Note that NodeId 0 is always valid since we don't let the source
  // node be removed from the graph.
  //
  // The current implementation of Graph maintains the invariant that the
  // first two nodes are the source and sink nodes, and all other nodes are op
  // nodes. This method (op_nodes()) relies on this invariant.
  NodeIter begin(this, 0);
  NodeIter end(this, num_node_ids());
  if (begin != end) {
    ++begin;
  }
  if (begin != end) {
    ++begin;
  }
  return gtl::make_range(begin, end);
}

inline void Node::set_assigned_device_name_index(int index) {
  graph_->CheckDeviceNameIndex(index);
  assigned_device_name_index_ = index;
}

inline void Node::set_assigned_device_name(const std::string& device_name) {
  graph_->set_assigned_device_name(this, device_name);
}

inline const std::string& Node::assigned_device_name() const {
  return graph_->get_assigned_device_name(*this);
}

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_GRAPH_GRAPH_H_