/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// A Graph describes a set of computations that are to be
// performed, as well as the dependencies between those
// computations. The basic model is a DAG (directed acyclic graph) with
// * internal nodes representing computational operations to be performed;
// * edges representing dependencies, indicating that the target may only be
//   executed once the source has completed; and
// * predefined "source" (start) and "sink" (finish) nodes -- the source
//   should be the only node that doesn't depend on anything, and the sink
//   should be the only node that nothing depends on.
//
// Note: Node ids are intended to be relatively dense in the
// 0..max_id range, but there may be gaps since ids won't be reused.
//
// Note: Some dependencies between operations are due to one operation
// consuming the output of another. In fact operations can produce
// multiple outputs and consume multiple inputs, and some
// optimizations will care about which specific outputs are connected
// to which specific inputs.  We therefore represent data dependency
// between output O of layer A and input I of layer B using
// "input index" and "output index" labels per edge.
37 #ifndef TENSORFLOW_CORE_GRAPH_GRAPH_H_
38 #define TENSORFLOW_CORE_GRAPH_GRAPH_H_
39 
40 #include <functional>
41 #include <memory>
42 #include <string>
43 #include <vector>
44 
45 #include "absl/types/optional.h"
46 #include "tensorflow/core/framework/full_type.pb.h"
47 #include "tensorflow/core/framework/function.h"
48 #include "tensorflow/core/framework/node_def.pb.h"
49 #include "tensorflow/core/framework/node_def_util.h"
50 #include "tensorflow/core/framework/op.h"
51 #include "tensorflow/core/framework/types.h"
52 #include "tensorflow/core/graph/edgeset.h"
53 #include "tensorflow/core/lib/core/arena.h"
54 #include "tensorflow/core/lib/core/refcount.h"
55 #include "tensorflow/core/lib/core/status.h"
56 #include "tensorflow/core/lib/gtl/iterator_range.h"
57 #include "tensorflow/core/platform/logging.h"
58 #include "tensorflow/core/platform/macros.h"
59 #include "tensorflow/core/platform/stringpiece.h"
60 #include "tensorflow/core/platform/types.h"
61 
62 namespace tensorflow {
63 
64 class Edge;
65 class EdgeSetTest;
66 class Graph;
67 class GraphDef;
68 class Node;
69 struct OutputTensor;
70 class VersionDef;
71 class WhileContext;
72 
73 class NeighborIter;     // Declared below
74 class NodeIter;         // Declared below
75 
// Indicates where the graph instance originated from.
enum class ConstructionContext {
  kNotTracked,     // Not tracked.
  kDirectSession,  // From `tensorflow::DirectSession`, TF1 session API.
  kEagerRuntime,   // Registered from TF2 eager runtime.
};
82 
// A node in a Graph: pairs the user-supplied NodeDef with the graph-level
// state the runtime tracks for it (id, in/out edge sets, assigned device,
// while-loop context). Nodes are created and owned by a Graph; see
// Graph::AddNode()/CopyNode(). Not thread-safe (see the Graph class comment).
class Node {
 public:
  std::string DebugString() const;
  // Unique id of this node within the owning Graph; -1 until Initialize()
  // is called (see id_ below).
  int id() const { return id_; }
  // Id of the node that cost accounting is attributed to; -1 if none.
  int cost_id() const { return cost_id_; }
  const std::string& name() const;
  void set_name(std::string name);
  const std::string& type_string() const;

  // def() provides the NodeDef the user supplied, but the specifics
  // of this Node may have changed due to placement, optimization, etc.
  // In particular:
  // * def().name() will match name();
  // * def().op() will match type_string() and op_def().name();
  // * def().input() is not reliable, use "in_edges()" below instead;
  // * def().device() is the "user's requested device" and may not match
  //   the actual assigned device, see assigned_device_name() below;
  // * def().attr() is authoritative.
  // TODO(irving): Replace with NodeInfo.
  const NodeDef& def() const;
  const OpDef& op_def() const;

  // TODO(mdan): This is only used by control_flow_deps_o_chains. Remove?
  NodeDef* mutable_def();

  // Input and output types, as declared by the op registration.
  int32 num_inputs() const;
  DataType input_type(int32_t i) const;
  const DataTypeVector& input_types() const;

  int32 num_outputs() const;
  DataType output_type(int32_t o) const;
  const DataTypeVector& output_types() const;

  // The device requested by the user.  For the actual assigned device,
  // use assigned_device_name() below.
  const std::string& requested_device() const;

  // This changes the user requested device but not necessarily the device
  // on which the operation will run.
  void set_requested_device(const std::string& device);

  // This gives the device the runtime has assigned this node to.  If
  // you want the device the user requested, use def().device() instead.
  // TODO(josh11b): Validate that the assigned_device, if not empty:
  // fully specifies a device, and satisfies def().device().
  // TODO(josh11b): Move assigned_device_name outside of Node into a
  // NodeId->DeviceName map.
  const std::string& assigned_device_name() const;
  void set_assigned_device_name(const std::string& device_name);
  // Index 0 is reserved for "no device assigned", hence the > 0 check.
  bool has_assigned_device_name() const {
    return assigned_device_name_index_ > 0;
  }
  // Index of the assigned device name within Graph::device_names_.
  int assigned_device_name_index() const { return assigned_device_name_index_; }
  void set_assigned_device_name_index(int index);

  // Sets 'original_node_names' field of this node's DebugInfo proto to
  // 'names'.
  void set_original_node_names(const std::vector<string>& names);

  // Read only access to attributes.
  AttrSlice attrs() const;

  // Inputs requested by the NodeDef.  For the actual inputs, use in_edges.
  const protobuf::RepeatedPtrField<string>& requested_inputs() const;

  // Get the neighboring nodes via edges either in or out of this node.  This
  // includes control edges.
  gtl::iterator_range<NeighborIter> in_nodes() const;
  gtl::iterator_range<NeighborIter> out_nodes() const;
  const EdgeSet& in_edges() const { return in_edges_; }
  const EdgeSet& out_edges() const { return out_edges_; }

  // Node type helpers. The source node always has id 0 and the sink node
  // id 1 (see Graph::kSourceId/kSinkId).
  bool IsSource() const { return id() == 0; }
  bool IsSink() const { return id() == 1; }
  // Anything other than the special Source & Sink nodes.
  bool IsOp() const { return id() > 1; }

  // Node class helpers: class_ is derived from the op's type_string() at
  // initialization time (see GetNodeClassForOp below).
  bool IsSwitch() const { return class_ == NC_SWITCH; }
  bool IsMerge() const { return class_ == NC_MERGE; }
  bool IsEnter() const { return class_ == NC_ENTER; }
  bool IsExit() const { return class_ == NC_EXIT; }
  bool IsNextIteration() const { return class_ == NC_NEXT_ITERATION; }
  bool IsLoopCond() const { return class_ == NC_LOOP_COND; }
  bool IsControlTrigger() const { return class_ == NC_CONTROL_TRIGGER; }
  bool IsSend() const { return class_ == NC_SEND || class_ == NC_HOST_SEND; }
  bool IsRecv() const { return class_ == NC_RECV || class_ == NC_HOST_RECV; }
  bool IsConstant() const { return class_ == NC_CONSTANT; }
  bool IsVariable() const { return class_ == NC_VARIABLE; }
  bool IsIdentity() const { return class_ == NC_IDENTITY; }
  bool IsGetSessionHandle() const { return class_ == NC_GET_SESSION_HANDLE; }
  bool IsGetSessionTensor() const { return class_ == NC_GET_SESSION_TENSOR; }
  bool IsDeleteSessionTensor() const {
    return class_ == NC_DELETE_SESSION_TENSOR;
  }
  bool IsControlFlow() const {
    return (class_ != NC_OTHER) &&  // Fast path
           (IsSwitch() || IsMerge() || IsEnter() || IsExit() ||
            IsNextIteration());
  }
  bool IsHostSend() const { return class_ == NC_HOST_SEND; }
  bool IsHostRecv() const { return class_ == NC_HOST_RECV; }
  bool IsScopedAllocator() const { return class_ == NC_SCOPED_ALLOCATOR; }
  bool IsCollective() const { return class_ == NC_COLLECTIVE; }

  bool IsMetadata() const { return class_ == NC_METADATA; }
  bool IsFakeParam() const { return class_ == NC_FAKE_PARAM; }
  bool IsPartitionedCall() const { return class_ == NC_PARTITIONED_CALL; }

  // Returns true if this node is any kind of function call node.
  //
  // NOTE: "function call nodes" include partitioned call ops, symbolic gradient
  // ops, and ops whose type_string is the name of a function ("function ops").
  bool IsFunctionCall() const {
    return class_ == NC_PARTITIONED_CALL || class_ == NC_FUNCTION_OP ||
           class_ == NC_SYMBOLIC_GRADIENT;
  }

  bool IsIfNode() const { return class_ == NC_IF; }
  bool IsWhileNode() const { return class_ == NC_WHILE; }
  bool IsCaseNode() const { return class_ == NC_CASE; }
  // Is this node a function input?
  bool IsArg() const { return class_ == NC_ARG; }
  // Is this node a function output?
  bool IsRetval() const { return class_ == NC_RETVAL; }

  bool IsDistributedCommunication() const {
    return op_def().is_distributed_communication();
  }

  // Adds attribute `name` with value `val`, then refreshes cached
  // node properties (see UpdateProperties()).
  template <typename T>
  void AddAttr(const std::string& name, const T& val) {
    SetAttrValue(val, AddAttrHelper(name));
    UpdateProperties();
  }

  // Overload that moves the string vector into the attribute value,
  // avoiding a copy.
  void AddAttr(const std::string& name, std::vector<string>&& val) {
    MoveAttrValue(std::move(val), AddAttrHelper(name));
    UpdateProperties();
  }

  void ClearAttr(const std::string& name);

  // Returns into '*e' the edge connecting to the 'idx' input of this Node.
  Status input_edge(int idx, const Edge** e) const;

  // Returns into '*edges' the input data edges of this Node, indexed by input
  // number. Does not return control edges.
  Status input_edges(std::vector<const Edge*>* edges) const;

  // Returns into '*n' the node that has an output connected to the
  // 'idx' input of this Node.
  Status input_node(int idx, const Node** n) const;
  Status input_node(int idx, Node** n) const;

  // Returns into '*t' the idx-th input tensor of this node, represented as the
  // output tensor of input_node(idx).
  Status input_tensor(int idx, OutputTensor* t) const;

  WhileContext* while_ctx() const { return while_ctx_; }
  // May only be called once, and only on a while-loop Exit node (DCHECKed).
  void set_while_ctx(WhileContext* while_ctx) {
    DCHECK(IsExit());
    DCHECK(while_ctx_ == nullptr);
    while_ctx_ = while_ctx;
  }

  std::shared_ptr<NodeProperties> properties() const { return props_; }

  // Sets the stack trace for the node. Assumes that getting and setting the
  // stack trace for a given node will not race.
  void SetStackTrace(const std::shared_ptr<AbstractStackTrace>& stack_trace) {
    stack_trace_ = stack_trace;
  }

  // Get the stack trace for when the node was instantiated.
  const std::shared_ptr<AbstractStackTrace>& GetStackTrace() const {
    return stack_trace_;
  }

  // Called after an attr has changed. Decides whether we need to update some
  // property of the node (stored in props_).
  void UpdateProperties();

 private:
  friend class Graph;
  Node();

  // Stack trace for the user code for node instantiation. Can be shared across
  // multiple nodes (e.g. when inlining).
  std::shared_ptr<AbstractStackTrace> stack_trace_;

  // Releases memory from props_, in addition to restoring *this to its
  // uninitialized state.
  void Clear();

  // Make a copy of the Node's props_ if props_ is shared with
  // other nodes. This must be called before mutating properties,
  // e.g. in AddAttr.
  void MaybeCopyOnWrite();

  AttrValue* AddAttrHelper(const std::string& name);

  // A set of mutually exclusive classes for different kinds of nodes,
  // class_ is initialized in the Node::Initialize routine based on the
  // node's type_string().
  enum NodeClass {
    NC_UNINITIALIZED,
    NC_SWITCH,
    NC_MERGE,
    NC_ENTER,
    NC_EXIT,
    NC_NEXT_ITERATION,
    NC_LOOP_COND,
    NC_CONTROL_TRIGGER,
    NC_SEND,
    NC_HOST_SEND,
    NC_RECV,
    NC_HOST_RECV,
    NC_CONSTANT,
    NC_VARIABLE,
    NC_IDENTITY,
    NC_GET_SESSION_HANDLE,
    NC_GET_SESSION_TENSOR,
    NC_DELETE_SESSION_TENSOR,
    NC_METADATA,
    NC_SCOPED_ALLOCATOR,
    NC_COLLECTIVE,
    NC_FAKE_PARAM,
    NC_PARTITIONED_CALL,
    NC_FUNCTION_OP,
    NC_SYMBOLIC_GRADIENT,
    NC_IF,
    NC_WHILE,
    NC_CASE,
    NC_ARG,
    NC_RETVAL,
    NC_OTHER  // Not a special kind of node
  };

  void Initialize(int id, int cost_id, std::shared_ptr<NodeProperties> props,
                  NodeClass node_class);

  // Maps an op type string to the NodeClass used by the Is*() helpers above.
  static NodeClass GetNodeClassForOp(const std::string& ts);

  int id_;       // -1 until Initialize() is called
  int cost_id_;  // -1 if there is no corresponding cost accounting node
  NodeClass class_;

  EdgeSet in_edges_;
  EdgeSet out_edges_;

  // NOTE(skyewm): inheriting from core::RefCounted may have a slight
  // performance benefit over using shared_ptr, at the cost of manual ref
  // counting
  std::shared_ptr<NodeProperties> props_;

  // Index within Graph::device_names_ of the name of device assigned
  // to perform this computation.
  int assigned_device_name_index_;

  // A back-pointer to the Graph that owns this node.  Currently, this exists
  // solely to allow Node::[set_]assigned_device_name() to work. However, if all
  // callers of Node::[set_]assigned_device_name() are modified to use the
  // equivalent methods defined directly on Graph, then we can remove this
  // field and reclaim that memory.
  Graph* graph_;

  // Set if this is an exit node of a while loop with an associated
  // WhileContext. Otherwise null. (This is only set for exit nodes because
  // they're the first nodes of a loop encountered while creating the gradient
  // graph. Exit nodes that are part of while loop gradient graphs will not have
  // this set.)
  WhileContext* while_ctx_;

  TF_DISALLOW_COPY_AND_ASSIGN(Node);
};
361 
// Stores debug information associated with the Node.
struct NodeDebugInfo {
  const std::string name;                  // The node's name.
  std::vector<string> original_node_names; // Names of nodes this one derives
                                           // from (e.g. via optimization).

  // NOTE(review): constructors are non-explicit, which permits implicit
  // conversion from Node/NodeDef — presumably intentional for convenience in
  // error-reporting helpers; confirm before marking explicit.
  NodeDebugInfo(const Node& n);
  NodeDebugInfo(const NodeDef& ndef);
  NodeDebugInfo(StringPiece node_name, bool has_experimental_debug_info,
                const NodeDef_ExperimentalDebugInfo& experimental_debug_info);
};
372 
// Represents an input of a node, i.e., the `index`-th input to `node`.
struct InputTensor {
  Node* node;  // Consuming node (not owned); null when default-constructed.
  int index;   // Which input slot of `node` this refers to.

  InputTensor(Node* n, int i) : node(n), index(i) {}
  InputTensor() : node(nullptr), index(0) {}

  // Returns true if this InputTensor is identical to 'other'. Nodes are
  // compared using pointer equality.
  bool operator==(const InputTensor& other) const;

  // A hash function for InputTensors. Nodes are hashed based on their pointer
  // value.
  struct Hash {
    uint64 operator()(InputTensor const& s) const;
  };
};
391 
// Represents an output of a node, i.e., the `index`-th output of `node`. Note
// that a single `OutputTensor` can correspond to multiple `Edge`s if the output
// is consumed by multiple destination nodes.
struct OutputTensor {
  Node* node;  // Producing node (not owned); null when default-constructed.
  int index;   // Which output slot of `node` this refers to.

  OutputTensor(Node* n, int i) : node(n), index(i) {}
  OutputTensor() : node(nullptr), index(0) {}

  // Returns true if this OutputTensor is identical to 'other'. Nodes are
  // compared using pointer equality.
  bool operator==(const OutputTensor& other) const;

  // A hash function for OutputTensors. Nodes are hashed based on their pointer
  // value.
  struct Hash {
    uint64 operator()(OutputTensor const& s) const;
  };
};
412 
// An edge of a Graph: connects output slot `src_output()` of node `src()` to
// input slot `dst_input()` of node `dst()`. Edges are created and owned by
// the Graph (see Graph::AddEdge/AddControlEdge), hence the private
// constructor and friend declarations below.
class Edge {
 public:
  Node* src() const { return src_; }
  Node* dst() const { return dst_; }
  // Unique id of this edge within the owning Graph.
  int id() const { return id_; }

  // Return the index of the source output that produces the data
  // carried by this edge.  The special value kControlSlot is used
  // for control dependencies.
  int src_output() const { return src_output_; }

  // Return the index of the destination input that consumes the data
  // carried by this edge.  The special value kControlSlot is used
  // for control dependencies.
  int dst_input() const { return dst_input_; }

  // Return true iff this is an edge that indicates a control-flow
  // (as opposed to a data-flow) dependency.
  bool IsControlEdge() const;

  std::string DebugString() const;

 private:
  // Only Graph (and EdgeSetTest) may construct and initialize edges.
  Edge() {}

  friend class EdgeSetTest;
  friend class Graph;
  Node* src_;
  Node* dst_;
  int id_;
  int src_output_;
  int dst_input_;
};
446 
447 // Allows for iteration of the edges of a Graph, by iterating the underlying
448 // Graph.edges_ vector while skipping over null entries.
449 class GraphEdgesIterable {
450  private:
451   const std::vector<Edge*>& edges_;
452 
453  public:
GraphEdgesIterable(const std::vector<Edge * > & edges)454   explicit GraphEdgesIterable(const std::vector<Edge*>& edges)
455       : edges_(edges) {}
456 
457   typedef Edge* value_type;
458 
459   class const_iterator {
460    private:
461     // The underlying iterator.
462     std::vector<value_type>::const_iterator iter_;
463 
464     // The end of the underlying iterator.
465     std::vector<value_type>::const_iterator end_;
466 
467     // Advances iter_ until it reaches a non-null item, or reaches the end.
apply_filter()468     void apply_filter() {
469       while (iter_ != end_ && *iter_ == nullptr) {
470         ++iter_;
471       }
472     }
473 
474    public:
const_iterator(std::vector<value_type>::const_iterator iter,std::vector<value_type>::const_iterator end)475     const_iterator(std::vector<value_type>::const_iterator iter,
476                    std::vector<value_type>::const_iterator end)
477         : iter_(iter), end_(end) {
478       apply_filter();
479     }
480 
481     bool operator==(const const_iterator& other) const {
482       return iter_ == other.iter_;
483     }
484 
485     bool operator!=(const const_iterator& other) const {
486       return iter_ != other.iter_;
487     }
488 
489     // This is the prefix increment operator (++x), which is the operator
490     // used by C++ range iteration (for (x : y) ...).  We intentionally do not
491     // provide a postfix increment operator.
492     const_iterator& operator++() {
493       ++iter_;
494       apply_filter();
495       return *this;
496     }
497 
498     value_type operator*() { return *iter_; }
499   };
500 
begin()501   const_iterator begin() {
502     return const_iterator(edges_.begin(), edges_.end());
503   }
end()504   const_iterator end() { return const_iterator(edges_.end(), edges_.end()); }
505 };
506 
507 // Thread compatible but not thread safe.
508 class Graph {
509  public:
510   // Constructs a graph with a single SOURCE (always id kSourceId) and a
511   // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK.
512   //
513   // The graph can hold ops found in the registry. `ops`s lifetime must be at
514   // least that of the constructed graph's.
515   explicit Graph(const OpRegistryInterface* ops);
516 
517   // Constructs a graph with a single SOURCE (always id kSourceId) and a
518   // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK.
519   //
520   // The graph can hold ops found in `flib_def`. Unlike the constructor taking
521   // an OpRegistryInterface, this constructor copies the function definitions in
522   // `flib_def` so its lifetime may be shorter than that of the graph's. The
523   // OpRegistryInterface backing `flib_def` must still have the lifetime of the
524   // graph though.
525   explicit Graph(const FunctionLibraryDefinition& flib_def);
526 
527   ~Graph();
528 
529   // Clone the current graph into a new one.
530   std::unique_ptr<Graph> Clone();
531 
532   static const int kControlSlot;
533 
534   // The GraphDef version range of this graph (see graph.proto).
535   const VersionDef& versions() const;
536   void set_versions(const VersionDef& versions);
537 
538   // Adds a new node to this graph, and returns it. Infers the Op and
539   // input/output types for the node. *this owns the returned instance.
540   // Returns nullptr and sets *status on error.
541   Node* AddNode(NodeDef node_def, Status* status);
542 
543   // Copies *node, which may belong to another graph, to a new node,
544   // which is returned.  Does not copy any edges.  *this owns the
545   // returned instance.
546   Node* CopyNode(const Node* node);
547 
548   // Removes a node from this graph, including all edges from or to it.
549   // *node should not be accessed after calling this function.
550   // REQUIRES: node->IsOp()
551   void RemoveNode(Node* node);
552 
553   void Copy(const Graph& src);
554 
555   // Adds an edge that connects the xth output of `source` to the yth input of
556   // `dest` and returns it. Does not update dest's NodeDef.
557   const Edge* AddEdge(Node* source, int x, Node* dest, int y);
558 
559   // Adds a control edge (no data flows along this edge) that connects `source`
560   // to `dest`. If `dest`s NodeDef is missing the corresponding control input,
561   // adds the control input.
562   //
563   // If such a control edge already exists and `allow_duplicates` is false, no
564   // edge is added and the function returns nullptr. Otherwise the edge is
565   // unconditionally created and returned. The NodeDef is not updated if
566   // `allow_duplicates` is true.
567   // TODO(skyewm): // TODO(skyewm): allow_duplicates is needed only by
568   // graph_partition.cc. Figure out if we can do away with it.
569   const Edge* AddControlEdge(Node* source, Node* dest,
570                              bool allow_duplicates = false);
571 
572   // Removes edge from the graph. Does not update the destination node's
573   // NodeDef.
574   // REQUIRES: The edge must exist.
575   void RemoveEdge(const Edge* edge);
576 
577   // Removes control edge `edge` from the graph. Note that this also updates
578   // the corresponding NodeDef to reflect the change.
579   // REQUIRES: The control edge must exist.
580   void RemoveControlEdge(const Edge* e);
581 
582   // Updates the input to a node.  The existing edge to `dst` is removed and an
583   // edge from `new_src` to `dst` is created. The NodeDef associated with `dst`
584   // is also updated.
585   Status UpdateEdge(Node* new_src, int new_src_index, Node* dst, int dst_index);
586 
587   // Like AddEdge but updates dst's NodeDef. Used to add an input edge to a
588   // "While" op during gradient construction, see AddInputWhileHack in
589   // python_api.h for more details.
590   Status AddWhileInputHack(Node* new_src, int new_src_index, Node* dst);
591 
592   // Adds the function and gradient definitions in `fdef_lib` to this graph's op
593   // registry. Ignores duplicate functions, and returns a bad status if an
594   // imported function differs from an existing function or op with the same
595   // name.
596   Status AddFunctionLibrary(const FunctionDefLibrary& fdef_lib);
597 
598   // The number of live nodes in the graph.
599   //
600   // Because nodes can be removed from the graph, num_nodes() is often
601   // smaller than num_node_ids(). If one needs to create an array of
602   // nodes indexed by node ids, num_node_ids() should be used as the
603   // array's size.
num_nodes()604   int num_nodes() const { return num_nodes_; }
605 
606   // The number of live nodes in the graph, excluding the Source and Sink nodes.
num_op_nodes()607   int num_op_nodes() const {
608     DCHECK_GE(num_nodes_, 2);
609     return num_nodes_ - 2;
610   }
611 
612   // The number of live edges in the graph.
613   //
614   // Because edges can be removed from the graph, num_edges() is often
615   // smaller than num_edge_ids(). If one needs to create an array of
616   // edges indexed by edge ids, num_edge_ids() should be used as the
617   // array's size.
num_edges()618   int num_edges() const { return num_edges_; }
619 
620   // Serialize the nodes starting at `from_node_id` to a GraphDef.
621   void ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const;
622 
623   // Serialize to a GraphDef.
624   void ToGraphDef(GraphDef* graph_def) const;
625 
626   // This version can be called from debugger to inspect the graph content.
627   // Use the previous version outside debug context for efficiency reasons.
628   //
629   // Note: We do not expose a DebugString() API, since GraphDef.DebugString() is
630   // not defined in some TensorFlow builds.
631   GraphDef ToGraphDefDebug() const;
632 
633   // Generate new node name with the specified prefix that is unique
634   // across this graph.
635   std::string NewName(StringPiece prefix);
636 
637   // Access to the list of all nodes.  Example usage:
638   //   for (Node* node : graph.nodes()) { ... }
639   gtl::iterator_range<NodeIter> nodes() const;
640 
641   // Access to the list of all nodes, excluding the Source and Sink nodes.
642   gtl::iterator_range<NodeIter> op_nodes() const;
643 
644   // Returns one more than the maximum id assigned to any node.
num_node_ids()645   int num_node_ids() const { return nodes_.size(); }
646 
647   // Returns the node associated with an id, or nullptr if no node
648   // with that id (the node with that id was removed and the id has
649   // not yet been re-used). *this owns the returned instance.
650   // REQUIRES: 0 <= id < num_node_ids().
FindNodeId(int id)651   Node* FindNodeId(int id) const { return nodes_[id]; }
652 
653   // Returns one more than the maximum id assigned to any edge.
num_edge_ids()654   int num_edge_ids() const { return edges_.size(); }
655 
656   // Returns the Edge associated with an id, or nullptr if no edge
657   // with that id (the edge with that id was removed and the id has
658   // not yet been re-used). *this owns the returned instance.
659   // REQUIRES: 0 <= id < num_edge_ids().
FindEdgeId(int id)660   const Edge* FindEdgeId(int id) const { return edges_[id]; }
661 
662   // Access to the set of all edges.  Example usage:
663   //   for (const Edge* e : graph.edges()) { ... }
edges()664   GraphEdgesIterable edges() const { return GraphEdgesIterable(edges_); }
665 
666   // The pre-defined nodes.
667   enum { kSourceId = 0, kSinkId = 1 };
source_node()668   Node* source_node() const { return FindNodeId(kSourceId); }
sink_node()669   Node* sink_node() const { return FindNodeId(kSinkId); }
670 
op_registry()671   const OpRegistryInterface* op_registry() const { return &ops_; }
flib_def()672   const FunctionLibraryDefinition& flib_def() const { return ops_; }
673 
674   // TODO(mdan): This is only used by control_flow_deps_o_chains. Remove?
mutable_flib_def()675   FunctionLibraryDefinition* mutable_flib_def() { return &ops_; }
676 
CheckDeviceNameIndex(int index)677   void CheckDeviceNameIndex(int index) {
678     DCHECK_GE(index, 0);
679     DCHECK_LT(index, static_cast<int>(device_names_.size()));
680   }
681 
682   int InternDeviceName(const std::string& device_name);
683 
get_assigned_device_name(const Node & node)684   const std::string& get_assigned_device_name(const Node& node) const {
685     return device_names_[node.assigned_device_name_index()];
686   }
687 
set_assigned_device_name_index(Node * node,int device_name_index)688   void set_assigned_device_name_index(Node* node, int device_name_index) {
689     CheckDeviceNameIndex(device_name_index);
690     node->assigned_device_name_index_ = device_name_index;
691   }
692 
set_assigned_device_name(Node * node,const std::string & device_name)693   void set_assigned_device_name(Node* node, const std::string& device_name) {
694     node->assigned_device_name_index_ = InternDeviceName(device_name);
695   }
696 
697   // Returns OK if `node` is non-null and belongs to this graph
698   Status IsValidNode(const Node* node) const;
699 
700   // Returns OK if IsValidNode(`node`) and `idx` is a valid output.  Does not
701   // accept control outputs.
702   Status IsValidOutputTensor(const Node* node, int idx) const;
703 
704   // Returns OK if IsValidNode(`node`) and `idx` a valid input.  Does not accept
705   // control inputs.
706   Status IsValidInputTensor(const Node* node, int idx) const;
707 
708   // Create and return a new WhileContext owned by this graph. This is called
709   // when a new while loop is created. `frame_name` must be unique among
710   // WhileContexts in this graph.
711   Status AddWhileContext(StringPiece frame_name, std::vector<Node*> enter_nodes,
712                          std::vector<Node*> exit_nodes,
713                          OutputTensor cond_output,
714                          std::vector<OutputTensor> body_inputs,
715                          std::vector<OutputTensor> body_outputs,
716                          WhileContext** result);
717 
718   // Builds a node name to node pointer index for all nodes in the graph.
719   std::unordered_map<string, Node*> BuildNodeNameIndex() const;
720 
GetConstArgIndicesCache()721   absl::optional<std::vector<bool>>& GetConstArgIndicesCache() const {
722     return const_arg_indices_cache_;
723   }
724 
725   // TODO(kkb): Add to the constructor when it becomes managable.
726   // Sets the graph construction context.
SetConstructionContext(ConstructionContext construction_context)727   void SetConstructionContext(ConstructionContext construction_context) {
728     construction_context_ = construction_context;
729   }
730 
731   // TODO(kkb): Rename to `GetConstructionContext` once we're comfortable
732   // making this stable and make it available widely.
733   // Returns the graph construction context. It's `kUnknown` if not set.
GetConstructionContextInternal()734   ConstructionContext GetConstructionContextInternal() const {
735     return construction_context_;
736   }
737 
  // Records the output FullType `type` for the node named `name`
  // (presumably into node_name_to_out_type_ -- implementation not visible
  // here; confirm against graph.cc).
  void SetNodeType(StringPiece name, const FullTypeDef& type);

  // Looks up the FullType recorded for the node named `name`, returned via
  // `*result`.  NOTE(review): behavior when `name` is absent is not visible
  // in this header -- confirm against the implementation.
  void NodeType(StringPiece name, FullTypeDef** result);

  // TODO(josh11b): uint64 hash() const;
743 
 private:
  // If cost_node is non-null, then cost accounting (in CostModel)
  // will be associated with that node rather than the new one being
  // created.
  //
  // Ownership of the returned Node is not transferred to caller.
  Node* AllocateNode(std::shared_ptr<NodeProperties> props,
                     const Node* cost_node, Node::NodeClass node_class);
  // Releases `node`'s storage (presumably onto free_nodes_ for reuse --
  // implementation not visible here).
  void ReleaseNode(Node* node);
  // Insert edge in free_edges_ for possible reuse.
  void RecycleEdge(const Edge* edge);
  // Registry of all known ops, including functions.
  FunctionLibraryDefinition ops_;

  // GraphDef versions
  const std::unique_ptr<VersionDef> versions_;

  // Allocator which will give us good locality.
  core::Arena arena_;

  // Map from node ids to allocated nodes.  nodes_[id] may be nullptr if
  // the node with that id was removed from the graph.
  std::vector<Node*> nodes_;

  // Types table.
  // TODO(mdan): Do not store these here. Instead, keep in a GraphDef field.
  std::unordered_set<TypeRef, TypeHasher> types_;

  // Experimental.
  // Map from node names to their outputs' FullType. Typically, the values
  // in this map are identical to those in types_, but that is not enforced or
  // guaranteed.
  //
  // The full type specification combines a Tensor's dtype, tensor_shape,
  // variant_val, etc. into a unified representation.
  // This definition may only contain concrete types (for example,
  // Tensor<TypeVar<'T'>> is not a valid node type).
  //
  // Presently, FullType duplicates any information found in `dtype`. When set,
  // it is always consistent with `dtype`. Eventually, `dtype` will be merged
  // with FullType.
  //
  // For example, if a TensorProto has `dtype=DT_INT32`, then
  // `full_type=FT_TENSOR[FT_INT32]`.
  // TODO(mdan): Do not store these here. Instead, keep in a GraphDef field.
  std::unordered_map<string, TypeRef> node_name_to_out_type_;

  // Number of nodes alive.
  int64 num_nodes_ = 0;

  // Map from edge ids to allocated edges.  edges_[id] may be nullptr if
  // the edge with that id was removed from the graph.
  std::vector<Edge*> edges_;

  // The number of entries in edges_ that are not nullptr.
  int num_edges_ = 0;

  // Allocated but free nodes and edges.
  std::vector<Node*> free_nodes_;
  std::vector<Edge*> free_edges_;

  // For generating unique names.
  int name_counter_ = 0;

  // In most graphs, the number of unique values used for the
  // Node::assigned_device_name() property is quite small.  If the graph is
  // large, then this duplication of values can consume a significant amount of
  // memory.  Instead, we represent the same information using an interning
  // table, which consists of a vector of unique strings (device_names_), as
  // well as a map (device_names_map_) from unique strings to indices within
  // the unique string table.
  //
  // The InternDeviceName() method handles adding a new entry into the table,
  // or locating the index of an existing entry.
  //
  // The fact that Node::assigned_device_name() is implemented using an
  // interning table is intentionally public.  This allows algorithms that
  // frequently access this field to do so efficiently, especially for the case
  // where the assigned_device_name of one Node is copied directly from that
  // of another Node.

  // A table of the unique assigned device names.  Indices do NOT correspond
  // to node IDs.  Index 0 is always the empty string.
  std::vector<string> device_names_;

  // Maps unique device names to indices within device_names_.
  std::unordered_map<string, int> device_names_map_;

  // All the while contexts owned by this graph, keyed by frame name,
  // corresponding to all the while loops contained in this graph (including
  // nested loops). The stored contexts are usually accessed via
  // AddWhileContext() or Node::while_ctx(), but this manages the lifetime.
  std::map<string, WhileContext> while_ctxs_;

  // Cache of the indices of the arguments which need to be constant for the
  // XLA compilation.  `mutable` so it can be written through the const
  // accessor GetConstArgIndicesCache().
  mutable absl::optional<std::vector<bool>> const_arg_indices_cache_;

  // Indicates the context in which this Graph instance was constructed.
  ConstructionContext construction_context_ = ConstructionContext::kNotTracked;

  TF_DISALLOW_COPY_AND_ASSIGN(Graph);
};
847 
848 // TODO(josh11b): We may want to support keeping an index on various
849 // node/edge attributes in a graph, particularly node names.
850 
851 // Helper routines
852 
// Each of the following free functions forwards to the identically named
// query on Node; the free-function form makes them convenient to use as
// predicates.
inline bool IsSource(const Node* node) { return node->IsSource(); }
inline bool IsSink(const Node* node) { return node->IsSink(); }
inline bool IsSwitch(const Node* node) { return node->IsSwitch(); }
inline bool IsMerge(const Node* node) { return node->IsMerge(); }
inline bool IsEnter(const Node* node) { return node->IsEnter(); }
inline bool IsExit(const Node* node) { return node->IsExit(); }
inline bool IsNextIteration(const Node* n) { return n->IsNextIteration(); }
inline bool IsLoopCond(const Node* node) { return node->IsLoopCond(); }
inline bool IsControlTrigger(const Node* n) { return n->IsControlTrigger(); }
inline bool IsSend(const Node* node) { return node->IsSend(); }
inline bool IsRecv(const Node* node) { return node->IsRecv(); }
inline bool IsHostSend(const Node* node) { return node->IsHostSend(); }
inline bool IsHostRecv(const Node* node) { return node->IsHostRecv(); }
866 
// True for Nodes that mediate the transfer of values between processes.
inline bool IsTransferNode(const Node* n) { return IsSend(n) || IsRecv(n); }

inline bool IsConstant(const Node* node) { return node->IsConstant(); }
inline bool IsVariable(const Node* node) { return node->IsVariable(); }
inline bool IsIdentity(const Node* node) { return node->IsIdentity(); }

// Returns true iff 'n' is a control flow node.
inline bool IsControlFlow(const Node* n) { return n->IsControlFlow(); }

// Returns true if the node only depends on its input's metadata
// (shape).  Specifically, returns true for "Size", "Shape" and "Rank" ops.
inline bool IsMetadata(const Node* n) { return n->IsMetadata(); }

inline bool IsScopedAllocator(const Node* n) { return n->IsScopedAllocator(); }

// A node counts as host-memory-preserving iff it is an Identity op or a
// control-flow op.
inline bool IsHostMemoryPreserving(const Node* node) {
  return IsIdentity(node) || IsControlFlow(node);
}

inline bool IsDistributedCommunication(const Node* n) {
  return n->IsDistributedCommunication();
}
890 
891 // NOTE: We declare Reference type of NodeIter and NeighborIter as Node* (see
892 // https://en.cppreference.com/w/cpp/iterator/iterator).
893 
894 // Iterator for stepping through the nodes of a graph.
895 class NodeIter
896     : public std::iterator<std::forward_iterator_tag, Node, std::ptrdiff_t,
897                            /*Pointer*/ Node*, /*Reference*/ Node*> {
898  public:
899   NodeIter(const Graph* graph, int id);
900   bool operator==(const NodeIter& rhs) const;
901   bool operator!=(const NodeIter& rhs) const;
902   void operator++();
903   reference operator*() const;
904   pointer operator->() const;
905 
906  private:
907   // Invariant: id_ == graph_->num_node_ids() || graph_->FindId(id_) != nullptr
908   const Graph* graph_;
909   int id_;
910 };
911 
912 // Iterator for stepping through the neighbors of a node.
913 class NeighborIter
914     : public std::iterator<std::forward_iterator_tag, Node, std::ptrdiff_t,
915                            /*Pointer*/ Node*, /*Reference*/ Node*> {
916  public:
917   NeighborIter(EdgeSet::const_iterator iter, bool incoming);
918   bool operator==(const NeighborIter& rhs) const;
919   bool operator!=(const NeighborIter& rhs) const;
920   void operator++();
921   reference operator*() const;
922   pointer operator->() const;
923 
924  private:
925   EdgeSet::const_iterator iter_;
926   bool incoming_;
927 };
928 
929 // IMPLEMENTATION DETAILS, PLEASE IGNORE
930 
inline NodeIter::NodeIter(const Graph* graph, int id)
    : graph_(graph), id_(id) {}

inline bool NodeIter::operator==(const NodeIter& rhs) const {
  // Comparing iterators that come from different graphs is a programming
  // error.
  DCHECK(graph_ == rhs.graph_);
  return id_ == rhs.id_;
}

inline bool NodeIter::operator!=(const NodeIter& rhs) const {
  return !(*this == rhs);
}
942 
943 inline void NodeIter::operator++() {
944   while (1) {
945     DCHECK_LE(id_, graph_->num_node_ids());
946     ++id_;
947     if (id_ >= graph_->num_node_ids() || graph_->FindNodeId(id_) != nullptr) {
948       return;
949     }
950   }
951 }
952 
// Dereferencing yields Node* rather than Node&; see the NOTE above about the
// Reference type of these iterators.
inline Node* NodeIter::operator*() const { return graph_->FindNodeId(id_); }

inline Node* NodeIter::operator->() const { return graph_->FindNodeId(id_); }
956 
inline NeighborIter::NeighborIter(EdgeSet::const_iterator iter, bool incoming)
    : iter_(iter), incoming_(incoming) {}

inline bool NeighborIter::operator==(const NeighborIter& rhs) const {
  // Two neighbor iterators are equal only if they are at the same edge AND
  // traverse in the same direction.
  return iter_ == rhs.iter_ && incoming_ == rhs.incoming_;
}

inline bool NeighborIter::operator!=(const NeighborIter& rhs) const {
  return !(*this == rhs);
}

inline void NeighborIter::operator++() { ++iter_; }
969 
970 inline Node* NeighborIter::operator*() const {
971   const Edge* e = *iter_;
972   return incoming_ ? e->src() : e->dst();
973 }
974 
975 inline Node* NeighborIter::operator->() const {
976   const Edge* e = *iter_;
977   return incoming_ ? e->src() : e->dst();
978 }
979 
// An edge is a control edge iff its source side occupies the reserved
// control slot.
inline bool Edge::IsControlEdge() const {
  // Note that if either src_output_ or dst_input_ is kControlSlot,
  // so is the other one (AddEdge checks this).
  return src_output_ == Graph::kControlSlot;
}
985 
// Range over every node in the graph, including the source and sink nodes.
inline gtl::iterator_range<NodeIter> Graph::nodes() const {
  // Note that NodeId 0 is always valid since we don't let the source
  // node be removed from the graph.
  return gtl::make_range(NodeIter(this, 0), NodeIter(this, num_node_ids()));
}
991 
op_nodes()992 inline gtl::iterator_range<NodeIter> Graph::op_nodes() const {
993   // Note that NodeId 0 is always valid since we don't let the source
994   // node be removed from the graph.
995   //
996   // The current implementation of Graph maintains the invariant that the
997   // first two nodes are the source and sink nodes, and all other nodes are op
998   // nodes. This method (op_nodes()) relies on this invariant.
999   NodeIter begin(this, 0);
1000   NodeIter end(this, num_node_ids());
1001   if (begin != end) {
1002     ++begin;
1003   }
1004   if (begin != end) {
1005     ++begin;
1006   }
1007   return gtl::make_range(begin, end);
1008 }
1009 
inline void Node::set_assigned_device_name_index(int index) {
  // `index` must be a valid position in the graph's interned device-name
  // table; CheckDeviceNameIndex enforces that before the store.
  graph_->CheckDeviceNameIndex(index);
  assigned_device_name_index_ = index;
}

inline void Node::set_assigned_device_name(const std::string& device_name) {
  // Delegates to the graph so `device_name` is interned in its device-name
  // table (see the comment above Graph::device_names_).
  graph_->set_assigned_device_name(this, device_name);
}

inline const std::string& Node::assigned_device_name() const {
  // Resolves this node's interned index back to the string stored in the
  // graph's device-name table.
  return graph_->get_assigned_device_name(*this);
}
1022 
1023 }  // namespace tensorflow
1024 
1025 #endif  // TENSORFLOW_CORE_GRAPH_GRAPH_H_
1026