/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// A Graph describes a set of computations that are to be
// performed, as well as the dependencies between those
// computations. The basic model is a DAG (directed acyclic graph) with
// * internal nodes representing computational operations to be performed;
// * edges representing dependencies, indicating that the target may only be
//   executed once the source has completed; and
// * predefined "source" (start) and "sink" (finish) nodes -- the source
//   should be the only node that doesn't depend on anything, and the sink
//   should be the only node that nothing depends on.
//
// Note: Node ids are intended to be relatively dense in the
// 0..max_id range, but there may be gaps since ids won't be reused.
//
// Note: Some dependencies between operations are due to one operation
// consuming the output of another. In fact, operations can produce
// multiple outputs and consume multiple inputs, and some
// optimizations will care about which specific outputs are connected
// to which specific inputs. We therefore represent the data dependency
// between output O of node A and input I of node B using
// "input index" and "output index" labels per edge.
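//
// As an illustrative sketch (not an API guarantee): if edge `e` carries the
// data dependency from output 0 of node `a` to input 1 of node `b`, then
//
//   e->src() == a;  e->src_output() == 0;
//   e->dst() == b;  e->dst_input() == 1;
//
// while a control dependency uses the special value Graph::kControlSlot for
// both the source-output and destination-input slots.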

#ifndef TENSORFLOW_CORE_GRAPH_GRAPH_H_
#define TENSORFLOW_CORE_GRAPH_GRAPH_H_

#include <functional>
#include <string>
#include <vector>

#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/edgeset.h"
#include "tensorflow/core/lib/core/arena.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

class Edge;
class EdgeSetTest;
class Graph;
class GraphDef;
class Node;
struct OutputTensor;
class VersionDef;
class WhileContext;

class NeighborIter;     // Declared below
class NodeIter;         // Declared below
struct NodeProperties;  // Defined in .cc

class Node {
 public:
  string DebugString() const;
  int id() const { return id_; }
  int cost_id() const { return cost_id_; }
  const string& name() const;
  void set_name(string name);
  const string& type_string() const;

  // def() provides the NodeDef the user supplied, but the specifics
  // of this Node may have changed due to placement, optimization, etc.
  // In particular:
  // * def().name() will match name();
  // * def().op() will match type_string() and op_def().name();
  // * def().input() is not reliable, use "in_edges()" below instead;
  // * def().device() is the "user's requested device" and may not match
  //   the actual assigned device, see assigned_device_name() below;
  // * def().attr() is authoritative.
  // TODO(irving): Replace with NodeInfo.
  const NodeDef& def() const;
  const OpDef& op_def() const;

  // input and output types
  int32 num_inputs() const;
  DataType input_type(int32 i) const;
  const DataTypeVector& input_types() const;

  int32 num_outputs() const;
  DataType output_type(int32 o) const;
  const DataTypeVector& output_types() const;

  // The device requested by the user. For the actual assigned device,
  // use assigned_device_name() below.
  const string& requested_device() const;

  // This changes the user-requested device but not necessarily the device on
  // which the operation will run.
  void set_requested_device(const string& device);

  // This gives the device the runtime has assigned this node to. If
  // you want the device the user requested, use def().device() instead.
  // TODO(josh11b): Validate that the assigned_device, if not empty:
  // fully specifies a device, and satisfies def().device().
  // TODO(josh11b): Move assigned_device_name outside of Node into a
  // NodeId->DeviceName map.
  const string& assigned_device_name() const;
  void set_assigned_device_name(const string& device_name);
  bool has_assigned_device_name() const {
    return assigned_device_name_index_ > 0;
  }
  int assigned_device_name_index() const { return assigned_device_name_index_; }
  void set_assigned_device_name_index(int index);

  // Sets 'original_node_names' field of this node's DebugInfo proto to
  // 'names'.
  void set_original_node_names(const std::vector<string>& names);

  // Read only access to attributes
  AttrSlice attrs() const;

  // Inputs requested by the NodeDef. For the actual inputs, use in_edges.
  const protobuf::RepeatedPtrField<string>& requested_inputs() const;

  // Get the neighboring nodes via edges either in or out of this node. This
  // includes control edges.
  gtl::iterator_range<NeighborIter> in_nodes() const;
  gtl::iterator_range<NeighborIter> out_nodes() const;
  const EdgeSet& in_edges() const { return in_edges_; }
  const EdgeSet& out_edges() const { return out_edges_; }

  // Node type helpers.
  bool IsSource() const { return id() == 0; }
  bool IsSink() const { return id() == 1; }
  // Anything other than the special Source & Sink nodes.
  bool IsOp() const { return id() > 1; }

  // Node class helpers
  bool IsSwitch() const { return class_ == NC_SWITCH; }
  bool IsMerge() const { return class_ == NC_MERGE; }
  bool IsEnter() const { return class_ == NC_ENTER; }
  bool IsExit() const { return class_ == NC_EXIT; }
  bool IsNextIteration() const { return class_ == NC_NEXT_ITERATION; }
  bool IsLoopCond() const { return class_ == NC_LOOP_COND; }
  bool IsControlTrigger() const { return class_ == NC_CONTROL_TRIGGER; }
  bool IsSend() const { return class_ == NC_SEND || class_ == NC_HOST_SEND; }
  bool IsRecv() const { return class_ == NC_RECV || class_ == NC_HOST_RECV; }
  bool IsConstant() const { return class_ == NC_CONSTANT; }
  bool IsVariable() const { return class_ == NC_VARIABLE; }
  bool IsIdentity() const { return class_ == NC_IDENTITY; }
  bool IsGetSessionHandle() const { return class_ == NC_GET_SESSION_HANDLE; }
  bool IsGetSessionTensor() const { return class_ == NC_GET_SESSION_TENSOR; }
  bool IsDeleteSessionTensor() const {
    return class_ == NC_DELETE_SESSION_TENSOR;
  }
  bool IsControlFlow() const {
    return (class_ != NC_OTHER) &&  // Fast path
           (IsSwitch() || IsMerge() || IsEnter() || IsExit() ||
            IsNextIteration());
  }
  bool IsHostSend() const { return class_ == NC_HOST_SEND; }
  bool IsHostRecv() const { return class_ == NC_HOST_RECV; }
  bool IsScopedAllocator() const { return class_ == NC_SCOPED_ALLOCATOR; }
  bool IsCollective() const { return class_ == NC_COLLECTIVE; }

  bool IsMetadata() const { return class_ == NC_METADATA; }
  bool IsFakeParam() const { return class_ == NC_FAKE_PARAM; }
  bool IsPartitionedCall() const { return class_ == NC_PARTITIONED_CALL; }

  // Returns true if this node is any kind of function call node.
  //
  // NOTE: "function call nodes" include partitioned call ops, symbolic gradient
  // ops, and ops whose type_string is the name of a function ("function ops").
  bool IsFunctionCall() const {
    return class_ == NC_PARTITIONED_CALL || class_ == NC_FUNCTION_OP ||
           class_ == NC_SYMBOLIC_GRADIENT;
  }

  bool IsIfNode() const { return class_ == NC_IF; }
  bool IsWhileNode() const { return class_ == NC_WHILE; }
  // Is this node a function input
  bool IsArg() const { return class_ == NC_ARG; }
  // Is this node a function output
  bool IsRetval() const { return class_ == NC_RETVAL; }

  template <typename T>
  void AddAttr(const string& name, const T& val) {
    SetAttrValue(val, AddAttrHelper(name));
    UpdateProperties();
  }

  void AddAttr(const string& name, std::vector<string>&& val) {
    MoveAttrValue(std::move(val), AddAttrHelper(name));
    UpdateProperties();
  }
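  //
  // Illustrative sketch only (the attr names below are hypothetical, not
  // required by this class); attrs added this way become visible through
  // attrs() and def().attr():
  //
  //   node->AddAttr("T", DT_FLOAT);
  //   node->AddAttr("N", 3);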

  void ClearAttr(const string& name);

  // Returns into '*e' the edge connecting to the 'idx' input of this Node.
  Status input_edge(int idx, const Edge** e) const;

  // Returns into '*edges' the input data edges of this Node, indexed by input
  // number. Does not return control edges.
  Status input_edges(std::vector<const Edge*>* edges) const;

  // Returns into '*n' the node that has an output connected to the
  // 'idx' input of this Node.
  Status input_node(int idx, const Node** n) const;
  Status input_node(int idx, Node** n) const;

  // Returns into '*t' the idx-th input tensor of this node, represented as the
  // output tensor of input_node(idx).
  Status input_tensor(int idx, OutputTensor* t) const;
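  //
  // Illustrative sketch of the lookup pattern these accessors support (the
  // local variable names are hypothetical; error handling omitted):
  //
  //   const Edge* e = nullptr;
  //   TF_RETURN_IF_ERROR(node->input_edge(/*idx=*/0, &e));
  //   Node* producer = e->src();
  //   int producer_output = e->src_output();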

  WhileContext* while_ctx() const { return while_ctx_; }
  void set_while_ctx(WhileContext* while_ctx) {
    DCHECK(IsExit());
    DCHECK(while_ctx_ == nullptr);
    while_ctx_ = while_ctx;
  }

 private:
  friend class Graph;
  Node();

  NodeProperties* properties() const { return props_.get(); }

  void Initialize(int id, int cost_id, std::shared_ptr<NodeProperties> props,
                  bool is_function_op);

  // Releases memory from props_, in addition to restoring *this to its
  // uninitialized state.
  void Clear();

  // Make a copy of the Node's props_ if props_ is shared with
  // other nodes. This must be called before mutating properties,
  // e.g. in AddAttr.
  void MaybeCopyOnWrite();

  // Called after an attr has changed. Decides whether we need to update some
  // property of the node (stored in props_).
  void UpdateProperties();

  AttrValue* AddAttrHelper(const string& name);

  // A set of mutually exclusive classes for different kinds of nodes,
  // class_ is initialized in the Node::Initialize routine based on the
  // node's type_string().
  enum NodeClass {
    NC_UNINITIALIZED,
    NC_SWITCH,
    NC_MERGE,
    NC_ENTER,
    NC_EXIT,
    NC_NEXT_ITERATION,
    NC_LOOP_COND,
    NC_CONTROL_TRIGGER,
    NC_SEND,
    NC_HOST_SEND,
    NC_RECV,
    NC_HOST_RECV,
    NC_CONSTANT,
    NC_VARIABLE,
    NC_IDENTITY,
    NC_GET_SESSION_HANDLE,
    NC_GET_SESSION_TENSOR,
    NC_DELETE_SESSION_TENSOR,
    NC_METADATA,
    NC_SCOPED_ALLOCATOR,
    NC_COLLECTIVE,
    NC_FAKE_PARAM,
    NC_PARTITIONED_CALL,
    NC_FUNCTION_OP,
    NC_SYMBOLIC_GRADIENT,
    NC_IF,
    NC_WHILE,
    NC_ARG,
    NC_RETVAL,
    NC_OTHER  // Not a special kind of node
  };

  static const std::unordered_map<string, NodeClass>& kNodeClassTable;

  static NodeClass GetNodeClassForOp(const string& ts);

  int id_;       // -1 until Initialize() is called
  int cost_id_;  // -1 if there is no corresponding cost accounting node
  NodeClass class_;

  EdgeSet in_edges_;
  EdgeSet out_edges_;

  // NOTE(skyewm): inheriting from core::RefCounted may have a slight
  // performance benefit over using shared_ptr, at the cost of manual ref
  // counting
  std::shared_ptr<NodeProperties> props_;

  // Index within Graph::device_names_ of the name of device assigned
  // to perform this computation.
  int assigned_device_name_index_;

  // A back-pointer to the Graph that owns this node. Currently, this exists
  // solely to allow Node::[set_]assigned_device_name() to work. However, if all
  // callers of Node::[set_]assigned_device_name() are modified to use the
  // equivalent methods defined directly on Graph, then we can remove this
  // field and reclaim that memory.
  Graph* graph_;

  // Set if this is an exit node of a while loop with an associated
  // WhileContext. Otherwise null. (This is only set for exit nodes because
  // they're the first nodes of a loop encountered while creating the gradient
  // graph. Exit nodes that are part of while loop gradient graphs will not have
  // this set.)
  WhileContext* while_ctx_;

  TF_DISALLOW_COPY_AND_ASSIGN(Node);
};

// Stores debug information associated with the Node.
struct NodeDebugInfo {
  const string name;
  std::vector<string> original_node_names;

  NodeDebugInfo(const Node& n);
  NodeDebugInfo(const NodeDef& ndef);
  NodeDebugInfo(StringPiece node_name, bool has_experimental_debug_info,
                const NodeDef_ExperimentalDebugInfo& experimental_debug_info);
};

// Represents an input of a node, i.e., the `index`-th input to `node`.
struct InputTensor {
  Node* node;
  int index;

  InputTensor(Node* n, int i) : node(n), index(i) {}
  InputTensor() : node(nullptr), index(0) {}

  // Returns true if this InputTensor is identical to 'other'. Nodes are
  // compared using pointer equality.
  bool operator==(const InputTensor& other) const;

  // A hash function for InputTensors. Nodes are hashed based on their pointer
  // value.
  struct Hash {
    uint64 operator()(InputTensor const& s) const;
  };
};

// Represents an output of a node, i.e., the `index`-th output of `node`. Note
// that a single `OutputTensor` can correspond to multiple `Edge`s if the output
// is consumed by multiple destination nodes.
struct OutputTensor {
  Node* node;
  int index;

  OutputTensor(Node* n, int i) : node(n), index(i) {}
  OutputTensor() : node(nullptr), index(0) {}

  // Returns true if this OutputTensor is identical to 'other'. Nodes are
  // compared using pointer equality.
  bool operator==(const OutputTensor& other) const;

  // A hash function for OutputTensors. Nodes are hashed based on their pointer
  // value.
  struct Hash {
    uint64 operator()(OutputTensor const& s) const;
  };
};

class Edge {
 public:
  Node* src() const { return src_; }
  Node* dst() const { return dst_; }
  int id() const { return id_; }

  // Return the index of the source output that produces the data
  // carried by this edge. The special value kControlSlot is used
  // for control dependencies.
  int src_output() const { return src_output_; }

  // Return the index of the destination input that consumes the data
  // carried by this edge. The special value kControlSlot is used
  // for control dependencies.
  int dst_input() const { return dst_input_; }

  // Return true iff this is an edge that indicates a control-flow
  // (as opposed to a data-flow) dependency.
  bool IsControlEdge() const;

  string DebugString() const;

 private:
  Edge() {}

  friend class EdgeSetTest;
  friend class Graph;
  Node* src_;
  Node* dst_;
  int id_;
  int src_output_;
  int dst_input_;
};

// Allows for iteration of the edges of a Graph, by iterating the underlying
// Graph.edges_ vector while skipping over null entries.
class GraphEdgesIterable {
 private:
  const std::vector<Edge*>& edges_;

 public:
  explicit GraphEdgesIterable(const std::vector<Edge*>& edges)
      : edges_(edges) {}

  typedef Edge* value_type;

  class const_iterator {
   private:
    // The underlying iterator.
    std::vector<value_type>::const_iterator iter_;

    // The end of the underlying iterator.
    std::vector<value_type>::const_iterator end_;

    // Advances iter_ until it reaches a non-null item, or reaches the end.
    void apply_filter() {
      while (iter_ != end_ && *iter_ == nullptr) {
        ++iter_;
      }
    }

   public:
    const_iterator(std::vector<value_type>::const_iterator iter,
                   std::vector<value_type>::const_iterator end)
        : iter_(iter), end_(end) {
      apply_filter();
    }

    bool operator==(const const_iterator& other) const {
      return iter_ == other.iter_;
    }

    bool operator!=(const const_iterator& other) const {
      return iter_ != other.iter_;
    }

    // This is the prefix increment operator (++x), which is the operator
    // used by C++ range iteration (for (x : y) ...). We intentionally do not
    // provide a postfix increment operator.
    const_iterator& operator++() {
      ++iter_;
      apply_filter();
      return *this;
    }

    value_type operator*() { return *iter_; }
  };

  const_iterator begin() {
    return const_iterator(edges_.begin(), edges_.end());
  }
  const_iterator end() { return const_iterator(edges_.end(), edges_.end()); }
};

// Thread compatible but not thread safe.
class Graph {
 public:
  // Constructs a graph with a single SOURCE (always id kSourceId) and a
  // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK.
  //
  // The graph can hold ops found in the registry. `ops`'s lifetime must be at
  // least that of the constructed graph's.
  explicit Graph(const OpRegistryInterface* ops);

  // Constructs a graph with a single SOURCE (always id kSourceId) and a
  // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK.
  //
  // The graph can hold ops found in `flib_def`. Unlike the constructor taking
  // an OpRegistryInterface, this constructor copies the function definitions in
  // `flib_def` so its lifetime may be shorter than that of the graph's. The
  // OpRegistryInterface backing `flib_def` must still have the lifetime of the
  // graph though.
  explicit Graph(const FunctionLibraryDefinition& flib_def);

  ~Graph();

  static const int kControlSlot;

  // The GraphDef version range of this graph (see graph.proto).
  const VersionDef& versions() const;
  void set_versions(const VersionDef& versions);

  // Adds a new node to this graph, and returns it. Infers the Op and
  // input/output types for the node. *this owns the returned instance.
  // Returns nullptr and sets *status on error.
  Node* AddNode(NodeDef node_def, Status* status);

  // Copies *node, which may belong to another graph, to a new node,
  // which is returned. Does not copy any edges. *this owns the
  // returned instance.
  Node* CopyNode(const Node* node);

  // Removes a node from this graph, including all edges from or to it.
  // *node should not be accessed after calling this function.
  // REQUIRES: node->IsOp()
  void RemoveNode(Node* node);

  // Adds an edge that connects the xth output of `source` to the yth input of
  // `dest` and returns it. Does not update dest's NodeDef.
  const Edge* AddEdge(Node* source, int x, Node* dest, int y);

  // Adds a control edge (no data flows along this edge) that connects `source`
  // to `dest`. If `dest`'s NodeDef is missing the corresponding control input,
  // adds the control input.
  //
  // If such a control edge already exists and `allow_duplicates` is false, no
  // edge is added and the function returns nullptr. Otherwise the edge is
  // unconditionally created and returned. The NodeDef is not updated if
  // `allow_duplicates` is true.
  // TODO(skyewm): allow_duplicates is needed only by graph_partition.cc.
  // Figure out if we can do away with it.
  const Edge* AddControlEdge(Node* source, Node* dest,
                             bool allow_duplicates = false);
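  //
  // A minimal illustrative sketch of wiring nodes together (`a_def`, `b_def`
  // and the slot numbers are hypothetical; NodeDef population and error
  // handling are elided):
  //
  //   Graph g(OpRegistry::Global());
  //   Status s;
  //   Node* a = g.AddNode(a_def, &s);
  //   Node* b = g.AddNode(b_def, &s);
  //   g.AddEdge(a, /*x=*/0, b, /*y=*/0);  // output 0 of `a` -> input 0 of `b`
  //   g.AddControlEdge(a, b);             // `b` must run after `a`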

  // Removes edge from the graph. Does not update the destination node's
  // NodeDef.
  // REQUIRES: The edge must exist.
  void RemoveEdge(const Edge* edge);

  // Removes control edge `edge` from the graph. Note that this also updates
  // the corresponding NodeDef to reflect the change.
  // REQUIRES: The control edge must exist.
  void RemoveControlEdge(const Edge* e);

  // Updates the input to a node. The existing edge to `dst` is removed and an
  // edge from `new_src` to `dst` is created. The NodeDef associated with `dst`
  // is also updated.
  Status UpdateEdge(Node* new_src, int new_src_index, Node* dst, int dst_index);

  // Like AddEdge but updates dst's NodeDef. Used to add an input edge to a
  // "While" op during gradient construction, see AddInputWhileHack in
  // python_api.h for more details.
  Status AddWhileInputHack(Node* new_src, int new_src_index, Node* dst);

  // Adds the function and gradient definitions in `fdef_lib` to this graph's op
  // registry. Ignores duplicate functions, and returns a bad status if an
  // imported function differs from an existing function or op with the same
  // name.
  Status AddFunctionLibrary(const FunctionDefLibrary& fdef_lib);

  // The number of live nodes in the graph.
  //
  // Because nodes can be removed from the graph, num_nodes() is often
  // smaller than num_node_ids(). If one needs to create an array of
  // nodes indexed by node ids, num_node_ids() should be used as the
  // array's size.
  int num_nodes() const { return num_nodes_; }
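  //
  // For example, to build a table indexed by node id (illustrative only;
  // `marked` is a hypothetical local):
  //
  //   std::vector<bool> marked(graph.num_node_ids(), false);
  //   for (Node* n : graph.nodes()) marked[n->id()] = true;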

  // The number of live nodes in the graph, excluding the Source and Sink nodes.
  int num_op_nodes() const {
    DCHECK_GE(num_nodes_, 2);
    return num_nodes_ - 2;
  }

  // The number of live edges in the graph.
  //
  // Because edges can be removed from the graph, num_edges() is often
  // smaller than num_edge_ids(). If one needs to create an array of
  // edges indexed by edge ids, num_edge_ids() should be used as the
  // array's size.
  int num_edges() const { return num_edges_; }

  // Serialize the nodes starting at `from_node_id` to a GraphDef.
  void ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const;

  // Serialize to a GraphDef.
  void ToGraphDef(GraphDef* graph_def) const;

  // This version can be called from debugger to inspect the graph content.
  // Use the previous version outside debug context for efficiency reasons.
  //
  // Note: We do not expose a DebugString() API, since GraphDef.DebugString() is
  // not defined in some TensorFlow builds.
  GraphDef ToGraphDefDebug() const;

  // Generate new node name with the specified prefix that is unique
  // across this graph.
  string NewName(StringPiece prefix);

  // Access to the list of all nodes. Example usage:
  //   for (Node* node : graph.nodes()) { ... }
  gtl::iterator_range<NodeIter> nodes() const;

  // Access to the list of all nodes, excluding the Source and Sink nodes.
  gtl::iterator_range<NodeIter> op_nodes() const;

  // Returns one more than the maximum id assigned to any node.
  int num_node_ids() const { return nodes_.size(); }

  // Returns the node associated with an id, or nullptr if no node
  // with that id (the node with that id was removed and the id has
  // not yet been re-used). *this owns the returned instance.
  // REQUIRES: 0 <= id < num_node_ids().
  Node* FindNodeId(int id) const { return nodes_[id]; }

  // Returns one more than the maximum id assigned to any edge.
  int num_edge_ids() const { return edges_.size(); }

  // Returns the Edge associated with an id, or nullptr if no edge
  // with that id (the edge with that id was removed and the id has
  // not yet been re-used). *this owns the returned instance.
  // REQUIRES: 0 <= id < num_edge_ids().
  const Edge* FindEdgeId(int id) const { return edges_[id]; }

  // Access to the set of all edges. Example usage:
  //   for (const Edge* e : graph.edges()) { ... }
  GraphEdgesIterable edges() const { return GraphEdgesIterable(edges_); }

  // The pre-defined nodes.
  enum { kSourceId = 0, kSinkId = 1 };
  Node* source_node() const { return FindNodeId(kSourceId); }
  Node* sink_node() const { return FindNodeId(kSinkId); }

  const OpRegistryInterface* op_registry() const { return &ops_; }
  const FunctionLibraryDefinition& flib_def() const { return ops_; }

  void CheckDeviceNameIndex(int index) {
    DCHECK_GE(index, 0);
    DCHECK_LT(index, static_cast<int>(device_names_.size()));
  }

  int InternDeviceName(const string& device_name);

  const string& get_assigned_device_name(const Node& node) const {
    return device_names_[node.assigned_device_name_index()];
  }

  void set_assigned_device_name_index(Node* node, int device_name_index) {
    CheckDeviceNameIndex(device_name_index);
    node->assigned_device_name_index_ = device_name_index;
  }

  void set_assigned_device_name(Node* node, const string& device_name) {
    node->assigned_device_name_index_ = InternDeviceName(device_name);
  }
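  // Because assigned device names are interned (see the comment on
  // device_names_ below), a device assignment can be copied between nodes by
  // index, without re-interning the string. Illustrative sketch only (`src`
  // and `dst` are hypothetical Node pointers in graph `g`):
  //
  //   g.set_assigned_device_name_index(dst, src->assigned_device_name_index());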

  // Returns OK if `node` is non-null and belongs to this graph
  Status IsValidNode(const Node* node) const;

  // Returns OK if IsValidNode(`node`) and `idx` is a valid output. Does not
  // accept control outputs.
  Status IsValidOutputTensor(const Node* node, int idx) const;

  // Returns OK if IsValidNode(`node`) and `idx` is a valid input. Does not
  // accept control inputs.
  Status IsValidInputTensor(const Node* node, int idx) const;

  // Create and return a new WhileContext owned by this graph. This is called
  // when a new while loop is created. `frame_name` must be unique among
  // WhileContexts in this graph.
  Status AddWhileContext(StringPiece frame_name, std::vector<Node*> enter_nodes,
                         std::vector<Node*> exit_nodes,
                         OutputTensor cond_output,
                         std::vector<OutputTensor> body_inputs,
                         std::vector<OutputTensor> body_outputs,
                         WhileContext** result);

  // Builds a node name to node pointer index for all nodes in the graph.
  std::unordered_map<string, Node*> BuildNodeNameIndex() const;
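  //
  // Illustrative usage sketch (the node name "my_node" is hypothetical):
  //
  //   auto name_index = graph.BuildNodeNameIndex();
  //   auto it = name_index.find("my_node");
  //   Node* n = (it == name_index.end()) ? nullptr : it->second;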

  // TODO(josh11b): uint64 hash() const;

 private:
  // If cost_node is non-null, then cost accounting (in CostModel)
  // will be associated with that node rather than the new one being
  // created.
  //
  // Ownership of the returned Node is not transferred to caller.
  Node* AllocateNode(std::shared_ptr<NodeProperties> props,
                     const Node* cost_node, bool is_function_op);
  void ReleaseNode(Node* node);
  // Insert edge in free_edges_ for possible reuse.
  void RecycleEdge(const Edge* edge);
  // Registry of all known ops, including functions.
  FunctionLibraryDefinition ops_;

  // GraphDef versions
  const std::unique_ptr<VersionDef> versions_;

  // Allocator which will give us good locality.
  core::Arena arena_;

  // Map from node ids to allocated nodes. nodes_[id] may be nullptr if
  // the node with that id was removed from the graph.
  std::vector<Node*> nodes_;

  // Number of nodes alive.
  int64 num_nodes_ = 0;

  // Map from edge ids to allocated edges. edges_[id] may be nullptr if
  // the edge with that id was removed from the graph.
  std::vector<Edge*> edges_;

  // The number of entries in edges_ that are not nullptr.
  int num_edges_ = 0;

  // Allocated but free nodes and edges.
  std::vector<Node*> free_nodes_;
  std::vector<Edge*> free_edges_;

  // For generating unique names.
  int name_counter_ = 0;

  // In most graphs, the number of unique values used for the
  // Node::assigned_device_name() property is quite small. If the graph is
  // large, then this duplication of values can consume a significant amount of
  // memory. Instead, we represent the same information using an interning
  // table, which consists of a vector of unique strings (device_names_), as
  // well as a map (device_names_map_) from unique strings to indices within the
  // unique string table.
  //
  // The InternDeviceName() method handles adding a new entry into the table,
  // or locating the index of an existing entry.
  //
  // The fact that Node::assigned_device_name() is implemented using an
  // interning table is intentionally public. This allows algorithms that
  // frequently access this field to do so efficiently, especially for the case
  // where the assigned_device_name of one Node is copied directly from that
  // of another Node.

  // A table of the unique assigned device names. Indices do NOT correspond
  // to node IDs. Index 0 is always the empty string.
  std::vector<string> device_names_;

  // Maps unique device names to indices within device_names_.
  std::unordered_map<string, int> device_names_map_;

  // All the while contexts owned by this graph, keyed by frame name,
  // corresponding to all the while loops contained in this graph (including
  // nested loops). The stored contexts are usually accessed via
  // AddWhileContext() or Node::while_ctx(), but this manages the lifetime.
  std::map<string, WhileContext> while_ctxs_;

  TF_DISALLOW_COPY_AND_ASSIGN(Graph);
};

// TODO(josh11b): We may want to support keeping an index on various
// node/edge attributes in a graph, particularly node names.

// Helper routines

inline bool IsSource(const Node* node) { return node->IsSource(); }
inline bool IsSink(const Node* node) { return node->IsSink(); }
inline bool IsSwitch(const Node* node) { return node->IsSwitch(); }
inline bool IsMerge(const Node* node) { return node->IsMerge(); }
inline bool IsEnter(const Node* node) { return node->IsEnter(); }
inline bool IsExit(const Node* node) { return node->IsExit(); }
inline bool IsNextIteration(const Node* n) { return n->IsNextIteration(); }
inline bool IsLoopCond(const Node* node) { return node->IsLoopCond(); }
inline bool IsControlTrigger(const Node* n) { return n->IsControlTrigger(); }
inline bool IsSend(const Node* node) { return node->IsSend(); }
inline bool IsRecv(const Node* node) { return node->IsRecv(); }
inline bool IsHostSend(const Node* node) { return node->IsHostSend(); }
inline bool IsHostRecv(const Node* node) { return node->IsHostRecv(); }

// True for Nodes that mediate the transfer of values between processes.
inline bool IsTransferNode(const Node* n) { return IsSend(n) || IsRecv(n); }

inline bool IsConstant(const Node* node) { return node->IsConstant(); }
inline bool IsVariable(const Node* node) { return node->IsVariable(); }
inline bool IsIdentity(const Node* node) { return node->IsIdentity(); }

// Returns true iff 'n' is a control flow node.
inline bool IsControlFlow(const Node* n) { return n->IsControlFlow(); }

// Returns true if the node only depends on its input's metadata
// (shape). Specifically, returns true for "Size", "Shape" and "Rank" ops.
inline bool IsMetadata(const Node* n) { return n->IsMetadata(); }

inline bool IsScopedAllocator(const Node* n) { return n->IsScopedAllocator(); }

inline bool IsHostMemoryPreserving(const Node* node) {
  return IsIdentity(node) || IsControlFlow(node);
}

// NOTE: We declare Reference type of NodeIter and NeighborIter as Node* (see
// https://en.cppreference.com/w/cpp/iterator/iterator).

// Iterator for stepping through the nodes of a graph.
class NodeIter
    : public std::iterator<std::forward_iterator_tag, Node, std::ptrdiff_t,
                           /*Pointer*/ Node*, /*Reference*/ Node*> {
 public:
  NodeIter(const Graph* graph, int id);
  bool operator==(const NodeIter& rhs) const;
  bool operator!=(const NodeIter& rhs) const;
  void operator++();
  reference operator*() const;
  pointer operator->() const;

 private:
  // Invariant: id_ == graph_->num_node_ids() || graph_->FindNodeId(id_) != nullptr
  const Graph* graph_;
  int id_;
};

// Iterator for stepping through the neighbors of a node.
class NeighborIter
    : public std::iterator<std::forward_iterator_tag, Node, std::ptrdiff_t,
                           /*Pointer*/ Node*, /*Reference*/ Node*> {
 public:
  NeighborIter(EdgeSet::const_iterator iter, bool incoming);
  bool operator==(const NeighborIter& rhs) const;
  bool operator!=(const NeighborIter& rhs) const;
  void operator++();
  reference operator*() const;
  pointer operator->() const;

 private:
  EdgeSet::const_iterator iter_;
  bool incoming_;
};

// IMPLEMENTATION DETAILS, PLEASE IGNORE

inline NodeIter::NodeIter(const Graph* graph, int id)
    : graph_(graph), id_(id) {}

inline bool NodeIter::operator==(const NodeIter& rhs) const {
  DCHECK(graph_ == rhs.graph_);
  return id_ == rhs.id_;
}

inline bool NodeIter::operator!=(const NodeIter& rhs) const {
  return !(*this == rhs);
}

inline void NodeIter::operator++() {
  while (1) {
    DCHECK_LE(id_, graph_->num_node_ids());
    ++id_;
    if (id_ >= graph_->num_node_ids() || graph_->FindNodeId(id_) != nullptr) {
      return;
    }
  }
}

inline Node* NodeIter::operator*() const { return graph_->FindNodeId(id_); }

inline Node* NodeIter::operator->() const { return graph_->FindNodeId(id_); }

inline NeighborIter::NeighborIter(EdgeSet::const_iterator iter, bool incoming)
    : iter_(iter), incoming_(incoming) {}

inline bool NeighborIter::operator==(const NeighborIter& rhs) const {
  return iter_ == rhs.iter_ && incoming_ == rhs.incoming_;
}

inline bool NeighborIter::operator!=(const NeighborIter& rhs) const {
  return !(*this == rhs);
}

inline void NeighborIter::operator++() { ++iter_; }

inline Node* NeighborIter::operator*() const {
  const Edge* e = *iter_;
  return incoming_ ? e->src() : e->dst();
}

inline Node* NeighborIter::operator->() const {
  const Edge* e = *iter_;
  return incoming_ ? e->src() : e->dst();
}

inline bool Edge::IsControlEdge() const {
  // Note that if either src_output_ or dst_input_ is kControlSlot,
  // so is the other one (AddEdge checks this).
  return src_output_ == Graph::kControlSlot;
}

inline gtl::iterator_range<NodeIter> Graph::nodes() const {
  // Note that NodeId 0 is always valid since we don't let the source
  // node be removed from the graph.
  return gtl::make_range(NodeIter(this, 0), NodeIter(this, num_node_ids()));
}

inline gtl::iterator_range<NodeIter> Graph::op_nodes() const {
  // Note that NodeId 0 is always valid since we don't let the source
  // node be removed from the graph.
  //
  // The current implementation of Graph maintains the invariant that the
  // first two nodes are the source and sink nodes, and all other nodes are op
  // nodes. This method (op_nodes()) relies on this invariant.
  NodeIter begin(this, 0);
  NodeIter end(this, num_node_ids());
  if (begin != end) {
    ++begin;
  }
  if (begin != end) {
    ++begin;
  }
  return gtl::make_range(begin, end);
}

inline void Node::set_assigned_device_name_index(int index) {
  graph_->CheckDeviceNameIndex(index);
  assigned_device_name_index_ = index;
}

inline void Node::set_assigned_device_name(const string& device_name) {
  graph_->set_assigned_device_name(this, device_name);
}

inline const string& Node::assigned_device_name() const {
  return graph_->get_assigned_device_name(*this);
}

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_GRAPH_GRAPH_H_