/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// A Graph describes a set of computations that are to be
// performed, as well as the dependencies between those
// computations. The basic model is a DAG (directed acyclic graph) with
// * internal nodes representing computational operations to be performed;
// * edges representing dependencies, indicating that the target may only be
//   executed once the source has completed; and
// * predefined "source" (start) and "sink" (finish) nodes -- the source
//   should be the only node that doesn't depend on anything, and the sink
//   should be the only node that nothing depends on.
//
// Note: Node ids are intended to be relatively dense in the
// 0..max_id range, but there may be gaps since ids won't be reused.
//
// Note: Some dependencies between operations are due to one operation
// consuming the output of another. In fact operations can produce
// multiple outputs and consume multiple inputs, and some
// optimizations will care about which specific outputs are connected
// to which specific inputs.  We therefore represent a data dependency
// between output O of node A and input I of node B using
// "input index" and "output index" labels per edge.

#ifndef TENSORFLOW_CORE_GRAPH_GRAPH_H_
#define TENSORFLOW_CORE_GRAPH_GRAPH_H_

#include <functional>
#include <memory>
#include <string>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/core/framework/full_type.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/edgeset.h"
#include "tensorflow/core/lib/core/arena.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

class Edge;
class EdgeSetTest;
class Graph;
class GraphDef;
class Node;
struct OutputTensor;
class VersionDef;
class WhileContext;

class NeighborIter;  // Declared below
class NodeIter;      // Declared below

// Indicates where the graph instance originated from.
enum class ConstructionContext {
  kNotTracked,     // Not tracked.
  kDirectSession,  // From `tensorflow::DirectSession`, TF1 session API.
  kEagerRuntime,   // Registered from TF2 eager runtime.
};

class Node {
 public:
  std::string DebugString() const;
  int id() const { return id_; }
  int cost_id() const { return cost_id_; }
  const std::string& name() const;
  void set_name(std::string name);
  const std::string& type_string() const;

  // def() provides the NodeDef the user supplied, but the specifics
  // of this Node may have changed due to placement, optimization, etc.
  // In particular:
  // * def().name() will match name();
  // * def().op() will match type_string() and op_def().name();
  // * def().input() is not reliable, use "in_edges()" below instead;
  // * def().device() is the "user's requested device" and may not match
  //   the actual assigned device, see assigned_device_name() below;
  // * def().attr() is authoritative.
  // TODO(irving): Replace with NodeInfo.
  const NodeDef& def() const;
  const OpDef& op_def() const;

  // TODO(mdan): This is only used by control_flow_deps_o_chains. Remove?
  NodeDef* mutable_def();

  // Input and output types.
  int32 num_inputs() const;
  DataType input_type(int32_t i) const;
  const DataTypeVector& input_types() const;

  int32 num_outputs() const;
  DataType output_type(int32_t o) const;
  const DataTypeVector& output_types() const;

  // The device requested by the user.  For the actual assigned device,
  // use assigned_device_name() below.
  const std::string& requested_device() const;

  // This changes the user requested device but not necessarily the device
  // on which the operation will run.
  void set_requested_device(const std::string& device);

  // This gives the device the runtime has assigned this node to.  If
  // you want the device the user requested, use def().device() instead.
  // TODO(josh11b): Validate that the assigned_device, if not empty,
  // fully specifies a device, and satisfies def().device().
  // TODO(josh11b): Move assigned_device_name outside of Node into a
  // NodeId->DeviceName map.
  const std::string& assigned_device_name() const;
  void set_assigned_device_name(const std::string& device_name);
  bool has_assigned_device_name() const {
    return assigned_device_name_index_ > 0;
  }
  int assigned_device_name_index() const { return assigned_device_name_index_; }
  void set_assigned_device_name_index(int index);

  // Sets the 'original_node_names' field of this node's DebugInfo proto to
  // 'names'.
  void set_original_node_names(const std::vector<string>& names);
  void set_original_func_names(const std::vector<string>& names);

  // Read-only access to attributes.
  AttrSlice attrs() const;

  // Inputs requested by the NodeDef.  For the actual inputs, use in_edges.
  const protobuf::RepeatedPtrField<string>& requested_inputs() const;

  // Get the neighboring nodes via edges either in or out of this node.  This
  // includes control edges.
  gtl::iterator_range<NeighborIter> in_nodes() const;
  gtl::iterator_range<NeighborIter> out_nodes() const;
  const EdgeSet& in_edges() const { return in_edges_; }
  const EdgeSet& out_edges() const { return out_edges_; }

  // Node type helpers.
  bool IsSource() const { return id() == 0; }
  bool IsSink() const { return id() == 1; }
  // Anything other than the special Source & Sink nodes.
  bool IsOp() const { return id() > 1; }

  // Node class helpers.
  bool IsSwitch() const { return class_ == NC_SWITCH; }
  bool IsMerge() const { return class_ == NC_MERGE; }
  bool IsEnter() const { return class_ == NC_ENTER; }
  bool IsExit() const { return class_ == NC_EXIT; }
  bool IsNextIteration() const { return class_ == NC_NEXT_ITERATION; }
  bool IsLoopCond() const { return class_ == NC_LOOP_COND; }
  bool IsControlTrigger() const { return class_ == NC_CONTROL_TRIGGER; }
  bool IsSend() const { return class_ == NC_SEND || class_ == NC_HOST_SEND; }
  bool IsRecv() const { return class_ == NC_RECV || class_ == NC_HOST_RECV; }
  bool IsConstant() const { return class_ == NC_CONSTANT; }
  bool IsVariable() const { return class_ == NC_VARIABLE; }
  bool IsIdentity() const { return class_ == NC_IDENTITY; }
  bool IsGetSessionHandle() const { return class_ == NC_GET_SESSION_HANDLE; }
  bool IsGetSessionTensor() const { return class_ == NC_GET_SESSION_TENSOR; }
  bool IsDeleteSessionTensor() const {
    return class_ == NC_DELETE_SESSION_TENSOR;
  }
  bool IsControlFlow() const {
    return (class_ != NC_OTHER) &&  // Fast path
           (IsSwitch() || IsMerge() || IsEnter() || IsExit() ||
            IsNextIteration());
  }
  bool IsHostSend() const { return class_ == NC_HOST_SEND; }
  bool IsHostRecv() const { return class_ == NC_HOST_RECV; }
  bool IsScopedAllocator() const { return class_ == NC_SCOPED_ALLOCATOR; }
  bool IsCollective() const { return class_ == NC_COLLECTIVE; }

  bool IsMetadata() const { return class_ == NC_METADATA; }
  bool IsFakeParam() const { return class_ == NC_FAKE_PARAM; }
  bool IsPartitionedCall() const { return class_ == NC_PARTITIONED_CALL; }

  // Returns true if this node is any kind of function call node.
  //
  // NOTE: "function call nodes" include partitioned call ops, symbolic gradient
  // ops, and ops whose type_string is the name of a function ("function ops").
  bool IsFunctionCall() const {
    return class_ == NC_PARTITIONED_CALL || class_ == NC_FUNCTION_OP ||
           class_ == NC_SYMBOLIC_GRADIENT;
  }

  bool IsIfNode() const { return class_ == NC_IF; }
  bool IsWhileNode() const { return class_ == NC_WHILE; }
  bool IsCaseNode() const { return class_ == NC_CASE; }
  // Is this node a function input?
  bool IsArg() const { return class_ == NC_ARG; }
  // Is this node a function output?
  bool IsRetval() const { return class_ == NC_RETVAL; }

  bool IsDistributedCommunication() const {
    return op_def().is_distributed_communication();
  }

  template <typename T>
  void AddAttr(const std::string& name, const T& val) {
    SetAttrValue(val, AddAttrHelper(name));
    UpdateProperties();
  }

  void AddAttr(const std::string& name, std::vector<string>&& val) {
    MoveAttrValue(std::move(val), AddAttrHelper(name));
    UpdateProperties();
  }
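
  // Example (illustrative sketch only): adding attributes to a node `n` that
  // is assumed to already exist in some graph.
  //
  //   n->AddAttr("T", DT_FLOAT);                      // typed via SetAttrValue
  //   n->AddAttr("_note", std::string("debug tag"));  // hypothetical attr name
  //   n->ClearAttr("_note");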

  void ClearAttr(const std::string& name);

  // Returns into '*e' the edge connecting to the 'idx' input of this Node.
  Status input_edge(int idx, const Edge** e) const;

  // Returns into '*edges' the input data edges of this Node, indexed by input
  // number. Does not return control edges.
  Status input_edges(std::vector<const Edge*>* edges) const;

  // Returns into '*n' the node that has an output connected to the
  // 'idx' input of this Node.
  Status input_node(int idx, const Node** n) const;
  Status input_node(int idx, Node** n) const;

  // Returns into '*t' the idx-th input tensor of this node, represented as the
  // output tensor of input_node(idx).
  Status input_tensor(int idx, OutputTensor* t) const;

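  // Example (illustrative sketch only): finding the producer of this node's
  // 0th input. Assumes `n` is a Node* whose data inputs are all connected.
  //
  //   const Edge* e = nullptr;
  //   TF_RETURN_IF_ERROR(n->input_edge(0, &e));
  //   Node* producer = e->src();
  //   int producer_output_slot = e->src_output();
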
  WhileContext* while_ctx() const { return while_ctx_; }
  void set_while_ctx(WhileContext* while_ctx) {
    DCHECK(IsExit());
    DCHECK(while_ctx_ == nullptr);
    while_ctx_ = while_ctx;
  }

  std::shared_ptr<NodeProperties> properties() const { return props_; }

  // Sets the stack trace for the node. Assumes that getting and setting the
  // stack trace for a given node will not race.
  void SetStackTrace(const std::shared_ptr<AbstractStackTrace>& stack_trace) {
    stack_trace_ = stack_trace;
  }

  // Get the stack trace for when the node was instantiated.
  const std::shared_ptr<AbstractStackTrace>& GetStackTrace() const {
    return stack_trace_;
  }

  // Called after an attr has changed. Decides whether we need to update some
  // property of the node (stored in props_).
  void UpdateProperties();

  // Erases type information from the node.
  void ClearTypeInfo();

  // Called after an incident non-control edge has changed. Does nothing if not
  // all input edges are defined.
  void RunForwardTypeInference();

 private:
  // TODO(mdan): Drop this.
  friend class Graph;
  Node();

  // Stack trace for the user code for node instantiation. Can be shared across
  // multiple nodes (e.g. when inlining).
  std::shared_ptr<AbstractStackTrace> stack_trace_;

  // Releases memory from props_, in addition to restoring *this to its
  // uninitialized state.
  void Clear();

  // Make a copy of the Node's props_ if props_ is shared with
  // other nodes. This must be called before mutating properties,
  // e.g. in AddAttr.
  void MaybeCopyOnWrite();

  AttrValue* AddAttrHelper(const std::string& name);

  // A set of mutually exclusive classes for different kinds of nodes.
  // class_ is initialized in the Node::Initialize routine based on the
  // node's type_string().
  enum NodeClass {
    NC_UNINITIALIZED,
    NC_SWITCH,
    NC_MERGE,
    NC_ENTER,
    NC_EXIT,
    NC_NEXT_ITERATION,
    NC_LOOP_COND,
    NC_CONTROL_TRIGGER,
    NC_SEND,
    NC_HOST_SEND,
    NC_RECV,
    NC_HOST_RECV,
    NC_CONSTANT,
    NC_VARIABLE,
    NC_IDENTITY,
    NC_GET_SESSION_HANDLE,
    NC_GET_SESSION_TENSOR,
    NC_DELETE_SESSION_TENSOR,
    NC_METADATA,
    NC_SCOPED_ALLOCATOR,
    NC_COLLECTIVE,
    NC_FAKE_PARAM,
    NC_PARTITIONED_CALL,
    NC_FUNCTION_OP,
    NC_SYMBOLIC_GRADIENT,
    NC_IF,
    NC_WHILE,
    NC_CASE,
    NC_ARG,
    NC_RETVAL,
    NC_OTHER  // Not a special kind of node
  };

  void Initialize(int id, int cost_id, std::shared_ptr<NodeProperties> props,
                  NodeClass node_class);

  static NodeClass GetNodeClassForOp(const std::string& ts);

  int id_;       // -1 until Initialize() is called
  int cost_id_;  // -1 if there is no corresponding cost accounting node
  NodeClass class_;

  EdgeSet in_edges_;
  EdgeSet out_edges_;

  // NOTE(skyewm): inheriting from core::RefCounted may have a slight
  // performance benefit over using shared_ptr, at the cost of manual ref
  // counting.
  std::shared_ptr<NodeProperties> props_;

  // Index within Graph::device_names_ of the name of the device assigned
  // to perform this computation.
  int assigned_device_name_index_;

  // A back-pointer to the Graph that owns this node.  Currently, this exists
  // solely to allow Node::[set_]assigned_device_name() to work. However, if all
  // callers of Node::[set_]assigned_device_name() are modified to use the
  // equivalent methods defined directly on Graph, then we can remove this
  // field and reclaim that memory.
  Graph* graph_;

  // Set if this is an exit node of a while loop with an associated
  // WhileContext. Otherwise null. (This is only set for exit nodes because
  // they're the first nodes of a loop encountered while creating the gradient
  // graph. Exit nodes that are part of while loop gradient graphs will not
  // have this set.)
  WhileContext* while_ctx_;

  TF_DISALLOW_COPY_AND_ASSIGN(Node);
};

// Stores debug information associated with the Node.
struct NodeDebugInfo {
  const std::string name;
  std::vector<string> original_node_names;
  std::vector<string> original_func_names;

  NodeDebugInfo(const Node& n);
  NodeDebugInfo(const NodeDef& ndef);
  NodeDebugInfo(StringPiece node_name, bool has_experimental_debug_info,
                const NodeDef_ExperimentalDebugInfo& experimental_debug_info);
};

// Represents an input of a node, i.e., the `index`-th input to `node`.
struct InputTensor {
  Node* node;
  int index;

  InputTensor(Node* n, int i) : node(n), index(i) {}
  InputTensor() : node(nullptr), index(0) {}

  // Returns true if this InputTensor is identical to 'other'. Nodes are
  // compared using pointer equality.
  bool operator==(const InputTensor& other) const;

  // A hash function for InputTensors. Nodes are hashed based on their pointer
  // value.
  struct Hash {
    uint64 operator()(InputTensor const& s) const;
  };
};

// Represents an output of a node, i.e., the `index`-th output of `node`. Note
// that a single `OutputTensor` can correspond to multiple `Edge`s if the output
// is consumed by multiple destination nodes.
struct OutputTensor {
  Node* node;
  int index;

  OutputTensor(Node* n, int i) : node(n), index(i) {}
  OutputTensor() : node(nullptr), index(0) {}

  // Returns true if this OutputTensor is identical to 'other'. Nodes are
  // compared using pointer equality.
  bool operator==(const OutputTensor& other) const;

  // A hash function for OutputTensors. Nodes are hashed based on their pointer
  // value.
  struct Hash {
    uint64 operator()(OutputTensor const& s) const;
  };
};

class Edge {
 public:
  Node* src() const { return src_; }
  Node* dst() const { return dst_; }
  int id() const { return id_; }

  // Return the index of the source output that produces the data
  // carried by this edge.  The special value kControlSlot is used
  // for control dependencies.
  int src_output() const { return src_output_; }

  // Return the index of the destination input that consumes the data
  // carried by this edge.  The special value kControlSlot is used
  // for control dependencies.
  int dst_input() const { return dst_input_; }

  // Return true iff this is an edge that indicates a control-flow
  // (as opposed to a data-flow) dependency.
  bool IsControlEdge() const;

  std::string DebugString() const;

 private:
  Edge() {}

  friend class EdgeSetTest;
  friend class Graph;
  Node* src_;
  Node* dst_;
  int id_;
  int src_output_;
  int dst_input_;
};

// Allows for iteration of the edges of a Graph, by iterating the underlying
// Graph.edges_ vector while skipping over null entries.
class GraphEdgesIterable {
 private:
  const std::vector<Edge*>& edges_;

 public:
  explicit GraphEdgesIterable(const std::vector<Edge*>& edges)
      : edges_(edges) {}

  typedef Edge* value_type;

  class const_iterator {
   private:
    // The underlying iterator.
    std::vector<value_type>::const_iterator iter_;

    // The end of the underlying iterator.
    std::vector<value_type>::const_iterator end_;

    // Advances iter_ until it reaches a non-null item, or reaches the end.
    void apply_filter() {
      while (iter_ != end_ && *iter_ == nullptr) {
        ++iter_;
      }
    }

   public:
    const_iterator(std::vector<value_type>::const_iterator iter,
                   std::vector<value_type>::const_iterator end)
        : iter_(iter), end_(end) {
      apply_filter();
    }

    bool operator==(const const_iterator& other) const {
      return iter_ == other.iter_;
    }

    bool operator!=(const const_iterator& other) const {
      return iter_ != other.iter_;
    }

    // This is the prefix increment operator (++x), which is the operator
    // used by C++ range iteration (for (x : y) ...).  We intentionally do not
    // provide a postfix increment operator.
    const_iterator& operator++() {
      ++iter_;
      apply_filter();
      return *this;
    }

    value_type operator*() { return *iter_; }
  };

  const_iterator begin() {
    return const_iterator(edges_.begin(), edges_.end());
  }
  const_iterator end() { return const_iterator(edges_.end(), edges_.end()); }
};

// Thread compatible but not thread safe.
class Graph {
 public:
  // Constructs a graph with a single SOURCE (always id kSourceId) and a
  // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK.
  //
  // The graph can hold ops found in the registry. `ops`'s lifetime must be at
  // least that of the constructed graph.
  explicit Graph(const OpRegistryInterface* ops);

  // Constructs a graph with a single SOURCE (always id kSourceId) and a
  // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK.
  //
  // The graph can hold ops found in `flib_def`. Unlike the constructor taking
  // an OpRegistryInterface, this constructor copies the function definitions in
  // `flib_def`, so its lifetime may be shorter than that of the graph. The
  // OpRegistryInterface backing `flib_def` must still have the lifetime of the
  // graph, though.
  explicit Graph(const FunctionLibraryDefinition& flib_def);

  ~Graph();

  // Clones the current graph into a new one.
  std::unique_ptr<Graph> Clone();

  static const int kControlSlot;

  // The GraphDef version range of this graph (see graph.proto).
  const VersionDef& versions() const;
  void set_versions(const VersionDef& versions);

  // Adds a new node to this graph, and returns it. Infers the Op and
  // input/output types for the node. *this owns the returned instance.
  // Returns nullptr and sets *status on error.
  Node* AddNode(NodeDef node_def, Status* status);

  // Same as above, but using StatusOr. This method is always preferred.
  StatusOr<Node*> AddNode(NodeDef node_def);
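
  // Example (illustrative sketch only): adding a node from a NodeDef. The op
  // name "NoOp" stands in for any registered op; attributes and error handling
  // are abbreviated.
  //
  //   NodeDef def;
  //   def.set_name("my_node");
  //   def.set_op("NoOp");
  //   StatusOr<Node*> result = graph.AddNode(std::move(def));
  //   if (!result.ok()) { /* handle error */ }
  //   Node* n = *result;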

  // Copies *node, which may belong to another graph, to a new node,
  // which is returned.  Does not copy any edges.  *this owns the
  // returned instance.
  Node* CopyNode(const Node* node);

  // Removes a node from this graph, including all edges from or to it.
  // *node should not be accessed after calling this function.
  // REQUIRES: node->IsOp()
  void RemoveNode(Node* node);

  void Copy(const Graph& src);

  // Removes all nodes from this graph, including all edges from or to them.
  // No Node* references into the Graph remain valid afterwards.
  void Clear();

  // Adds an edge that connects the xth output of `source` to the yth input of
  // `dest` and returns it. Does not update dest's NodeDef.
  const Edge* AddEdge(Node* source, int x, Node* dest, int y);

  // Adds a control edge (no data flows along this edge) that connects `source`
  // to `dest`. If `dest`'s NodeDef is missing the corresponding control input,
  // adds the control input.
  //
  // If such a control edge already exists and `allow_duplicates` is false, no
  // edge is added and the function returns nullptr. Otherwise the edge is
  // unconditionally created and returned. The NodeDef is not updated if
  // `allow_duplicates` is true.
  // TODO(skyewm): allow_duplicates is needed only by graph_partition.cc.
  // Figure out if we can do away with it.
  const Edge* AddControlEdge(Node* source, Node* dest,
                             bool allow_duplicates = false);
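
  // Example (illustrative sketch only): wiring nodes together. Assumes `a` and
  // `b` are op nodes already owned by `graph`.
  //
  //   graph.AddEdge(a, /*x=*/0, b, /*y=*/0);  // data edge: a:0 -> b:0
  //   graph.AddControlEdge(a, b);             // control edge; updates b's NodeDef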

  // Removes edge from the graph. Does not update the destination node's
  // NodeDef.
  // REQUIRES: The edge must exist.
  void RemoveEdge(const Edge* edge);

  // Removes control edge `edge` from the graph. Note that this also updates
  // the corresponding NodeDef to reflect the change.
  // REQUIRES: The control edge must exist.
  void RemoveControlEdge(const Edge* e);

  // Updates the input to a node.  The existing edge to `dst` is removed and an
  // edge from `new_src` to `dst` is created. The NodeDef associated with `dst`
  // is also updated.
  Status UpdateEdge(Node* new_src, int new_src_index, Node* dst, int dst_index);

  // Like AddEdge but updates dst's NodeDef. Used to add an input edge to a
  // "While" op during gradient construction, see AddInputWhileHack in
  // python_api.h for more details.
  Status AddWhileInputHack(Node* new_src, int new_src_index, Node* dst);

  // Adds the function and gradient definitions in `fdef_lib` to this graph's op
  // registry. Ignores duplicate functions, and returns a bad status if an
  // imported function differs from an existing function or op with the same
  // name.
  Status AddFunctionLibrary(const FunctionDefLibrary& fdef_lib);

  // The number of live nodes in the graph.
  //
  // Because nodes can be removed from the graph, num_nodes() is often
  // smaller than num_node_ids(). If one needs to create an array of
  // nodes indexed by node ids, num_node_ids() should be used as the
  // array's size.
  int num_nodes() const { return num_nodes_; }

  // The number of live nodes in the graph, excluding the Source and Sink nodes.
  int num_op_nodes() const {
    DCHECK_GE(num_nodes_, 2);
    return num_nodes_ - 2;
  }

  // The number of live edges in the graph.
  //
  // Because edges can be removed from the graph, num_edges() is often
  // smaller than num_edge_ids(). If one needs to create an array of
  // edges indexed by edge ids, num_edge_ids() should be used as the
  // array's size.
  int num_edges() const { return num_edges_; }

  // Serializes the nodes starting at `from_node_id` to a GraphDef.
  void ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const;

  // Serializes to a GraphDef.
  void ToGraphDef(GraphDef* graph_def) const;

  // This version can be called from the debugger to inspect the graph content.
  // Use the previous version outside a debug context for efficiency reasons.
  //
  // Note: We do not expose a DebugString() API, since GraphDef.DebugString() is
  // not defined in some TensorFlow builds.
  GraphDef ToGraphDefDebug() const;
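
  // Example (illustrative sketch only): serializing this graph to a GraphDef
  // proto.
  //
  //   GraphDef graph_def;
  //   graph.ToGraphDef(&graph_def);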

  // Generates a new node name with the specified prefix that is unique
  // across this graph.
  std::string NewName(StringPiece prefix);

  // Access to the list of all nodes.  Example usage:
  //   for (Node* node : graph.nodes()) { ... }
  gtl::iterator_range<NodeIter> nodes() const;

  // Access to the list of all nodes, excluding the Source and Sink nodes.
  gtl::iterator_range<NodeIter> op_nodes() const;

  // Returns one more than the maximum id assigned to any node.
  int num_node_ids() const { return nodes_.size(); }

  // Returns the node associated with an id, or nullptr if no node
  // with that id (the node with that id was removed and the id has
  // not yet been re-used). *this owns the returned instance.
  // REQUIRES: 0 <= id < num_node_ids().
  Node* FindNodeId(int id) const { return nodes_[id]; }

  // Returns one more than the maximum id assigned to any edge.
  int num_edge_ids() const { return edges_.size(); }

  // Returns the Edge associated with an id, or nullptr if no edge
  // with that id (the edge with that id was removed and the id has
  // not yet been re-used). *this owns the returned instance.
  // REQUIRES: 0 <= id < num_edge_ids().
  const Edge* FindEdgeId(int id) const { return edges_[id]; }

  // Access to the set of all edges.  Example usage:
  //   for (const Edge* e : graph.edges()) { ... }
  GraphEdgesIterable edges() const { return GraphEdgesIterable(edges_); }

  // The pre-defined nodes.
  enum { kSourceId = 0, kSinkId = 1 };
  Node* source_node() const { return FindNodeId(kSourceId); }
  Node* sink_node() const { return FindNodeId(kSinkId); }

  const OpRegistryInterface* op_registry() const { return &ops_; }
  const FunctionLibraryDefinition& flib_def() const { return ops_; }

  // TODO(mdan): This is only used by control_flow_deps_o_chains. Remove?
  FunctionLibraryDefinition* mutable_flib_def() { return &ops_; }

  void CheckDeviceNameIndex(int index) {
    DCHECK_GE(index, 0);
    DCHECK_LT(index, static_cast<int>(device_names_.size()));
  }

  int InternDeviceName(const std::string& device_name);

  const std::string& get_assigned_device_name(const Node& node) const {
    return device_names_[node.assigned_device_name_index()];
  }

  void set_assigned_device_name_index(Node* node, int device_name_index) {
    CheckDeviceNameIndex(device_name_index);
    node->assigned_device_name_index_ = device_name_index;
  }

  void set_assigned_device_name(Node* node, const std::string& device_name) {
    node->assigned_device_name_index_ = InternDeviceName(device_name);
  }
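
  // Example (illustrative sketch only): assigning devices through the
  // interning table. Both nodes below end up pointing at the same
  // device_names_ entry; the device string is a placeholder.
  //
  //   graph.set_assigned_device_name(n1, "/job:worker/task:0/device:CPU:0");
  //   graph.set_assigned_device_name_index(n2, n1->assigned_device_name_index());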

  // Returns OK if `node` is non-null and belongs to this graph.
  Status IsValidNode(const Node* node) const;

  // Returns OK if IsValidNode(`node`) and `idx` is a valid output.  Does not
  // accept control outputs.
  Status IsValidOutputTensor(const Node* node, int idx) const;

  // Returns OK if IsValidNode(`node`) and `idx` is a valid input.  Does not
  // accept control inputs.
  Status IsValidInputTensor(const Node* node, int idx) const;

  // Creates and returns a new WhileContext owned by this graph. This is called
  // when a new while loop is created. `frame_name` must be unique among
  // WhileContexts in this graph.
  Status AddWhileContext(StringPiece frame_name, std::vector<Node*> enter_nodes,
                         std::vector<Node*> exit_nodes,
                         OutputTensor cond_output,
                         std::vector<OutputTensor> body_inputs,
                         std::vector<OutputTensor> body_outputs,
                         WhileContext** result);

  // Builds a node name to node pointer index for all nodes in the graph.
  std::unordered_map<string, Node*> BuildNodeNameIndex() const;
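
  // Example (illustrative sketch only): looking up a node by name. The name
  // "my_node" is a placeholder.
  //
  //   auto name_index = graph.BuildNodeNameIndex();
  //   auto it = name_index.find("my_node");
  //   Node* n = (it == name_index.end()) ? nullptr : it->second;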

  absl::optional<std::vector<bool>>& GetConstArgIndicesCache() const {
    return const_arg_indices_cache_;
  }

  // TODO(kkb): Add to the constructor when it becomes manageable.
  // Sets the graph construction context.
  void SetConstructionContext(ConstructionContext construction_context) {
    construction_context_ = construction_context;
  }

  // TODO(kkb): Rename to `GetConstructionContext` once we're comfortable
  // making this stable and make it available widely.
  // Returns the graph construction context. It's `kNotTracked` if not set.
  ConstructionContext GetConstructionContextInternal() const {
    return construction_context_;
  }

  // TODO(josh11b): uint64 hash() const;

 private:
  // If cost_node is non-null, then cost accounting (in CostModel)
  // will be associated with that node rather than the new one being
  // created.
  //
  // Ownership of the returned Node is not transferred to the caller.
  Node* AllocateNode(std::shared_ptr<NodeProperties> props,
                     const Node* cost_node, Node::NodeClass node_class);
  void ReleaseNode(Node* node);
  // Inserts the edge in free_edges_ for possible reuse.
  void RecycleEdge(const Edge* edge);
  // Registry of all known ops, including functions.
  FunctionLibraryDefinition ops_;

  // GraphDef versions.
  const std::unique_ptr<VersionDef> versions_;

  // Allocator which will give us good locality.
  core::Arena arena_;

  // Map from node ids to allocated nodes.  nodes_[id] may be nullptr if
  // the node with that id was removed from the graph.
  std::vector<Node*> nodes_;

  // Number of nodes alive.
  int64_t num_nodes_ = 0;

  // Map from edge ids to allocated edges.  edges_[id] may be nullptr if
  // the edge with that id was removed from the graph.
  std::vector<Edge*> edges_;

  // The number of entries in edges_ that are not nullptr.
  int num_edges_ = 0;

  // Allocated but free nodes and edges.
  std::vector<Node*> free_nodes_;
  std::vector<Edge*> free_edges_;

  // For generating unique names.
  int name_counter_ = 0;

  // In most graphs, the number of unique values used for the
  // Node::assigned_device_name() property is quite small.  If the graph is
  // large, then this duplication of values can consume a significant amount of
  // memory.  Instead, we represent the same information using an interning
  // table, which consists of a vector of unique strings (device_names_), as
  // well as a map (device_names_map_) from unique strings to indices within
  // the unique string table.
  //
  // The InternDeviceName() method handles adding a new entry into the table,
  // or locating the index of an existing entry.
  //
  // The fact that Node::assigned_device_name() is implemented using an
  // interning table is intentionally public.  This allows algorithms that
  // frequently access this field to do so efficiently, especially for the case
  // where the assigned_device_name of one Node is copied directly from that
  // of another Node.

  // A table of the unique assigned device names.  Indices do NOT correspond
  // to node IDs.  Index 0 is always the empty string.
  std::vector<string> device_names_;

  // Maps unique device names to indices within device_names_.
  std::unordered_map<string, int> device_names_map_;

  // All the while contexts owned by this graph, keyed by frame name,
  // corresponding to all the while loops contained in this graph (including
  // nested loops). The stored contexts are usually accessed via
  // AddWhileContext() or Node::while_ctx(), but this manages the lifetime.
  std::map<string, WhileContext> while_ctxs_;

  // Cache of the indices of the arguments which need to be constant for XLA
  // compilation.
  mutable absl::optional<std::vector<bool>> const_arg_indices_cache_;

  // Indicates the context in which this Graph instance was constructed.
  ConstructionContext construction_context_ = ConstructionContext::kNotTracked;

  TF_DISALLOW_COPY_AND_ASSIGN(Graph);
};

// TODO(josh11b): We may want to support keeping an index on various
// node/edge attributes in a graph, particularly node names.

// Helper routines

inline bool IsSource(const Node* node) { return node->IsSource(); }
inline bool IsSink(const Node* node) { return node->IsSink(); }
inline bool IsSwitch(const Node* node) { return node->IsSwitch(); }
inline bool IsMerge(const Node* node) { return node->IsMerge(); }
inline bool IsEnter(const Node* node) { return node->IsEnter(); }
inline bool IsExit(const Node* node) { return node->IsExit(); }
inline bool IsNextIteration(const Node* n) { return n->IsNextIteration(); }
inline bool IsLoopCond(const Node* node) { return node->IsLoopCond(); }
inline bool IsControlTrigger(const Node* n) { return n->IsControlTrigger(); }
inline bool IsSend(const Node* node) { return node->IsSend(); }
inline bool IsRecv(const Node* node) { return node->IsRecv(); }
inline bool IsHostSend(const Node* node) { return node->IsHostSend(); }
inline bool IsHostRecv(const Node* node) { return node->IsHostRecv(); }

// True for Nodes that mediate the transfer of values between processes.
inline bool IsTransferNode(const Node* n) { return IsSend(n) || IsRecv(n); }

inline bool IsConstant(const Node* node) { return node->IsConstant(); }
inline bool IsVariable(const Node* node) { return node->IsVariable(); }
inline bool IsIdentity(const Node* node) { return node->IsIdentity(); }

// Returns true iff 'n' is a control flow node.
inline bool IsControlFlow(const Node* n) { return n->IsControlFlow(); }

// Returns true if the node only depends on its input's metadata
// (shape).  Specifically, returns true for "Size", "Shape" and "Rank" ops.
inline bool IsMetadata(const Node* n) { return n->IsMetadata(); }

inline bool IsScopedAllocator(const Node* n) { return n->IsScopedAllocator(); }

inline bool IsHostMemoryPreserving(const Node* node) {
  return IsIdentity(node) || IsControlFlow(node);
}

inline bool IsDistributedCommunication(const Node* n) {
  return n->IsDistributedCommunication();
}

// NOTE: We declare the Reference type of NodeIter and NeighborIter as Node*
// (see https://en.cppreference.com/w/cpp/iterator/iterator).

// Iterator for stepping through the nodes of a graph.
class NodeIter
    : public std::iterator<std::forward_iterator_tag, Node, std::ptrdiff_t,
                           /*Pointer*/ Node*, /*Reference*/ Node*> {
 public:
  NodeIter(const Graph* graph, int id);
  bool operator==(const NodeIter& rhs) const;
  bool operator!=(const NodeIter& rhs) const;
  void operator++();
  reference operator*() const;
  pointer operator->() const;

 private:
  // Invariant: id_ == graph_->num_node_ids() || graph_->FindNodeId(id_) != nullptr
  const Graph* graph_;
  int id_;
};

// Iterator for stepping through the neighbors of a node.
class NeighborIter
    : public std::iterator<std::forward_iterator_tag, Node, std::ptrdiff_t,
                           /*Pointer*/ Node*, /*Reference*/ Node*> {
 public:
  NeighborIter(EdgeSet::const_iterator iter, bool incoming);
  bool operator==(const NeighborIter& rhs) const;
  bool operator!=(const NeighborIter& rhs) const;
  void operator++();
  reference operator*() const;
  pointer operator->() const;

 private:
  EdgeSet::const_iterator iter_;
  bool incoming_;
};

// IMPLEMENTATION DETAILS, PLEASE IGNORE

inline NodeIter::NodeIter(const Graph* graph, int id)
    : graph_(graph), id_(id) {}

inline bool NodeIter::operator==(const NodeIter& rhs) const {
  DCHECK(graph_ == rhs.graph_);
  return id_ == rhs.id_;
}

inline bool NodeIter::operator!=(const NodeIter& rhs) const {
  return !(*this == rhs);
}

inline void NodeIter::operator++() {
  while (1) {
    DCHECK_LE(id_, graph_->num_node_ids());
    ++id_;
    if (id_ >= graph_->num_node_ids() || graph_->FindNodeId(id_) != nullptr) {
      return;
    }
  }
}

inline Node* NodeIter::operator*() const { return graph_->FindNodeId(id_); }

inline Node* NodeIter::operator->() const { return graph_->FindNodeId(id_); }

inline NeighborIter::NeighborIter(EdgeSet::const_iterator iter, bool incoming)
    : iter_(iter), incoming_(incoming) {}

inline bool NeighborIter::operator==(const NeighborIter& rhs) const {
  return iter_ == rhs.iter_ && incoming_ == rhs.incoming_;
}

inline bool NeighborIter::operator!=(const NeighborIter& rhs) const {
  return !(*this == rhs);
}

inline void NeighborIter::operator++() { ++iter_; }

inline Node* NeighborIter::operator*() const {
  const Edge* e = *iter_;
  return incoming_ ? e->src() : e->dst();
}

inline Node* NeighborIter::operator->() const {
  const Edge* e = *iter_;
  return incoming_ ? e->src() : e->dst();
}

inline bool Edge::IsControlEdge() const {
  // Note that if either src_output_ or dst_input_ is kControlSlot,
  // so is the other one (AddEdge checks this).
  return src_output_ == Graph::kControlSlot;
}

inline gtl::iterator_range<NodeIter> Graph::nodes() const {
  // Note that NodeId 0 is always valid since we don't let the source
  // node be removed from the graph.
  return gtl::make_range(NodeIter(this, 0), NodeIter(this, num_node_ids()));
}

inline gtl::iterator_range<NodeIter> Graph::op_nodes() const {
  // Note that NodeId 0 is always valid since we don't let the source
  // node be removed from the graph.
  //
  // The current implementation of Graph maintains the invariant that the
  // first two nodes are the source and sink nodes, and all other nodes are op
  // nodes. This method (op_nodes()) relies on this invariant.
  NodeIter begin(this, 0);
  NodeIter end(this, num_node_ids());
  if (begin != end) {
    ++begin;
  }
  if (begin != end) {
    ++begin;
  }
  return gtl::make_range(begin, end);
}

inline void Node::set_assigned_device_name_index(int index) {
  graph_->CheckDeviceNameIndex(index);
  assigned_device_name_index_ = index;
}

inline void Node::set_assigned_device_name(const std::string& device_name) {
  graph_->set_assigned_device_name(this, device_name);
}

inline const std::string& Node::assigned_device_name() const {
  return graph_->get_assigned_device_name(*this);
}

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_GRAPH_GRAPH_H_