1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // A Graph describes a set of computations that are to be
17 // performed, as well as the dependencies between those
18 // computations. The basic model is a DAG (directed acyclic graph) with
19 // * internal nodes representing computational operations to be performed;
20 // * edges represent dependencies, indicating the target may only be
21 // executed once the source has completed; and
22 // * predefined "source" (start) and "sink" (finish) nodes -- the source
23 // should be the only node that doesn't depend on anything, and the sink
24 // should be the only node that nothing depends on.
25 //
26 // Note: Node ids are intended to be relatively dense in the
27 // 0..max_id range, but there may be gaps since ids won't be reused.
28 //
// Note: Some dependencies between operations are due to one operation
// consuming the output of another. In fact operations can produce
// multiple outputs and consume multiple inputs, and some
// optimizations will care about which specific outputs are connected
// to which specific inputs. We therefore represent the data dependency
// between output O of node A and input I of node B using
// "input index" and "output index" labels per edge.
36
37 #ifndef TENSORFLOW_GRAPH_GRAPH_H_
38 #define TENSORFLOW_GRAPH_GRAPH_H_
39
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/edgeset.h"
#include "tensorflow/core/lib/core/arena.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
54
55 namespace tensorflow {
56
57 class Edge;
58 class EdgeSetTest;
59 class Graph;
60 class GraphDef;
61 class Node;
62 class VersionDef;
63 class WhileContext;
64
65 class NeighborIter; // Declared below
66 class NodeIter; // Declared below
67 class NodeProperties; // Defined in .cc
68
// A vertex of a Graph. Nodes can only be constructed by the friend class
// Graph, which owns them; callers obtain them via Graph::AddNode(),
// Graph::FindNodeId(), Graph::nodes(), and similar accessors.
class Node {
 public:
  string DebugString() const;
  // This node's id within its owning Graph (-1 until Initialize() is called).
  int id() const { return id_; }
  // Id of the node used for cost accounting (-1 if there is none).
  int cost_id() const { return cost_id_; }
  const string& name() const;
  const string& type_string() const;

  // def() provides the NodeDef the user supplied, but the specifics
  // of this Node may have changed due to placement, optimization, etc.
  // In particular:
  // * def().name() will match name();
  // * def().op() will match type_string() and op_def().name();
  // * def().input() is not reliable, use "in_edges()" below instead;
  // * def().device() is the "user's requested device" and may not match
  //   the actual assigned device, see assigned_device_name() below;
  // * def().attr() is authoritative.
  // TODO(irving): Replace with NodeInfo.
  const NodeDef& def() const;
  const OpDef& op_def() const;

  // Input and output types.
  int32 num_inputs() const;
  DataType input_type(int32 i) const;
  const DataTypeVector& input_types() const;

  int32 num_outputs() const;
  DataType output_type(int32 o) const;
  const DataTypeVector& output_types() const;

  // The device requested by the user. For the actual assigned device,
  // use assigned_device_name() below.
  const string& requested_device() const;

  // Changes the user-requested device, but not necessarily the device on
  // which the operation will run (see assigned_device_name()).
  void set_requested_device(const string& device);

  // This gives the device the runtime has assigned this node to. If
  // you want the device the user requested, use requested_device()
  // (equivalently def().device()) instead.
  // TODO(josh11b): Validate that the assigned_device, if not empty:
  // fully specifies a device, and satisfies def().device().
  // TODO(josh11b): Move assigned_device_name outside of Node into a
  // NodeId->DeviceName map.
  const string& assigned_device_name() const;
  void set_assigned_device_name(const string& device_name);
  // Index 0 of the owning graph's device-name table is the empty string, so
  // any positive index means a non-empty device has been assigned.
  bool has_assigned_device_name() const {
    return assigned_device_name_index_ > 0;
  }
  // Index of the assigned device name in the owning graph's interned
  // device-name table.
  int assigned_device_name_index() const { return assigned_device_name_index_; }
  void set_assigned_device_name_index(int index);

  // Read only access to attributes
  AttrSlice attrs() const;

  // Inputs requested by the NodeDef. For the actual inputs, use in_edges.
  const protobuf::RepeatedPtrField<string>& requested_inputs() const;

  // Get the neighboring nodes via edges either in or out of this node.
  gtl::iterator_range<NeighborIter> in_nodes() const;
  gtl::iterator_range<NeighborIter> out_nodes() const;
  const EdgeSet& in_edges() const { return in_edges_; }
  const EdgeSet& out_edges() const { return out_edges_; }

  // Node type helpers. Ids 0 and 1 are Graph::kSourceId and Graph::kSinkId
  // respectively (Graph is not yet complete here, hence the literals).
  bool IsSource() const { return id() == 0; }
  bool IsSink() const { return id() == 1; }
  // Anything other than the special Source & Sink nodes.
  bool IsOp() const { return id() > 1; }

  // Node class helpers. Each tests the NodeClass computed from
  // type_string() in Node::Initialize().
  bool IsSwitch() const { return class_ == NC_SWITCH; }
  bool IsMerge() const { return class_ == NC_MERGE; }
  bool IsEnter() const { return class_ == NC_ENTER; }
  bool IsExit() const { return class_ == NC_EXIT; }
  bool IsNextIteration() const { return class_ == NC_NEXT_ITERATION; }
  bool IsLoopCond() const { return class_ == NC_LOOP_COND; }
  bool IsControlTrigger() const { return class_ == NC_CONTROL_TRIGGER; }
  // Host sends/recvs also count as sends/recvs.
  bool IsSend() const { return class_ == NC_SEND || class_ == NC_HOST_SEND; }
  bool IsRecv() const { return class_ == NC_RECV || class_ == NC_HOST_RECV; }
  bool IsConstant() const { return class_ == NC_CONSTANT; }
  bool IsVariable() const { return class_ == NC_VARIABLE; }
  bool IsIdentity() const { return class_ == NC_IDENTITY; }
  bool IsGetSessionHandle() const { return class_ == NC_GET_SESSION_HANDLE; }
  bool IsGetSessionTensor() const { return class_ == NC_GET_SESSION_TENSOR; }
  bool IsDeleteSessionTensor() const {
    return class_ == NC_DELETE_SESSION_TENSOR;
  }
  // True for Switch, Merge, Enter, Exit and NextIteration nodes. Note that
  // LoopCond is deliberately not included.
  bool IsControlFlow() const {
    return (class_ != NC_OTHER) &&  // Fast path
           (IsSwitch() || IsMerge() || IsEnter() || IsExit() ||
            IsNextIteration());
  }
  bool IsHostSend() const { return class_ == NC_HOST_SEND; }
  bool IsHostRecv() const { return class_ == NC_HOST_RECV; }

  bool IsMetadata() const { return class_ == NC_METADATA; }

  // Sets attribute `name` to `val`. The attribute slot comes from
  // AddAttrHelper(), which is defined in the .cc file (presumably it calls
  // MaybeCopyOnWrite() first, per that method's comment -- confirm there).
  template <typename T>
  void AddAttr(const string& name, const T& val) {
    SetAttrValue(val, AddAttrHelper(name));
  }

  void ClearAttr(const string& name);

  // Returns into '*e' the edge connecting to the 'idx' input of this Node.
  Status input_edge(int idx, const Edge** e) const;

  // Returns into '*edges' the input data edges of this Node, indexed by input
  // number. Does not return control edges.
  Status input_edges(std::vector<const Edge*>* edges) const;

  // Returns into '*n' the node that has an output connected to the
  // 'idx' input of this Node.
  Status input_node(int idx, const Node** n) const;
  Status input_node(int idx, Node** n) const;

  WhileContext* while_ctx() const { return while_ctx_; }
  // May only be called on Exit nodes, and only once per node (both are
  // DCHECKed).
  void set_while_ctx(WhileContext* while_ctx) {
    DCHECK(IsExit());
    DCHECK(while_ctx_ == nullptr);
    while_ctx_ = while_ctx;
  }

 private:
  friend class Graph;
  Node();

  NodeProperties* properties() const { return props_.get(); }

  void Initialize(int id, int cost_id, std::shared_ptr<NodeProperties> props);

  // Releases memory from props_, in addition to restoring *this to its
  // uninitialized state.
  void Clear();

  // Make a copy of the Node's props_ if props_ is shared with
  // other nodes. This must be called before mutating properties,
  // e.g. in AddAttr.
  void MaybeCopyOnWrite();

  AttrValue* AddAttrHelper(const string& name);

  // A set of mutually exclusive classes for different kinds of nodes,
  // class_ is initialized in the Node::Initialize routine based on the
  // node's type_string().
  enum NodeClass {
    NC_UNINITIALIZED,
    NC_SWITCH,
    NC_MERGE,
    NC_ENTER,
    NC_EXIT,
    NC_NEXT_ITERATION,
    NC_LOOP_COND,
    NC_CONTROL_TRIGGER,
    NC_SEND,
    NC_HOST_SEND,
    NC_RECV,
    NC_HOST_RECV,
    NC_CONSTANT,
    NC_VARIABLE,
    NC_IDENTITY,
    NC_GET_SESSION_HANDLE,
    NC_GET_SESSION_TENSOR,
    NC_DELETE_SESSION_TENSOR,
    NC_METADATA,
    NC_OTHER  // Not a special kind of node
  };

  // Maps op type strings to their NodeClass (defined in the .cc file).
  static const std::unordered_map<string, NodeClass>& kNodeClassTable;

  // Returns the NodeClass for op type `ts`.
  static NodeClass GetNodeClassForOp(const string& ts);

  int id_;       // -1 until Initialize() is called
  int cost_id_;  // -1 if there is no corresponding cost accounting node
  NodeClass class_;

  EdgeSet in_edges_;
  EdgeSet out_edges_;

  // NOTE(skyewm): inheriting from core::RefCounted may have a slight
  // performance benefit over using shared_ptr, at the cost of manual ref
  // counting
  std::shared_ptr<NodeProperties> props_;

  // Index within Graph::device_names_ of the name of device assigned
  // to perform this computation.
  int assigned_device_name_index_;

  // A back-pointer to the Graph that owns this node. Currently, this exists
  // solely to allow Node::[set_]assigned_device_name() to work. However, if all
  // callers of Node::[set_]assigned_device_name() are modified to use the
  // equivalent methods defined directly on Graph, then we can remove this
  // field and reclaim that memory.
  Graph* graph_;

  // Set if this is an exit node of a while loop with an associated
  // WhileContext. Otherwise null. (This is only set for exit nodes because
  // they're the first nodes of a loop encountered while creating the gradient
  // graph. Exit nodes that are part of while loop gradient graphs will not have
  // this set.)
  WhileContext* while_ctx_;

  TF_DISALLOW_COPY_AND_ASSIGN(Node);
};
274
275 // Represents an input of a node, i.e., the `index`-th input to `node`.
276 struct InputTensor {
277 const Node* node;
278 int index;
279
InputTensorInputTensor280 InputTensor(const Node* n, int i) : node(n), index(i) {}
InputTensorInputTensor281 InputTensor() : node(nullptr), index(0) {}
282 };
283
284 // Represents an output of a node, i.e., the `index`-th output of `node`. Note
285 // that a single `OutputTensor` can correspond to multiple `Edge`s if the output
286 // is consumed by multiple destination nodes.
287 struct OutputTensor {
288 const Node* node;
289 int index;
290
OutputTensorOutputTensor291 OutputTensor(const Node* n, int i) : node(n), index(i) {}
OutputTensorOutputTensor292 OutputTensor() : node(nullptr), index(0) {}
293 };
294
295 class Edge {
296 public:
src()297 Node* src() const { return src_; }
dst()298 Node* dst() const { return dst_; }
id()299 int id() const { return id_; }
300
301 // Return the index of the source output that produces the data
302 // carried by this edge. The special value kControlSlot is used
303 // for control dependencies.
src_output()304 int src_output() const { return src_output_; }
305
306 // Return the index of the destination input that consumes the data
307 // carried by this edge. The special value kControlSlot is used
308 // for control dependencies.
dst_input()309 int dst_input() const { return dst_input_; }
310
311 // Return true iff this is an edge that indicates a control-flow
312 // (as opposed to a data-flow) dependency.
313 bool IsControlEdge() const;
314
315 string DebugString() const;
316
317 private:
Edge()318 Edge() {}
319
320 friend class EdgeSetTest;
321 friend class Graph;
322 Node* src_;
323 Node* dst_;
324 int id_;
325 int src_output_;
326 int dst_input_;
327 };
328
329 // Allows for iteration of the edges of a Graph, by iterating the underlying
330 // Graph.edges_ vector while skipping over null entries.
331 class GraphEdgesIterable {
332 private:
333 const std::vector<Edge*>& edges_;
334
335 public:
GraphEdgesIterable(const std::vector<Edge * > & edges)336 explicit GraphEdgesIterable(const std::vector<Edge*>& edges)
337 : edges_(edges) {}
338
339 typedef Edge* value_type;
340
341 class const_iterator {
342 private:
343 // The underlying iterator.
344 std::vector<value_type>::const_iterator iter_;
345
346 // The end of the underlying iterator.
347 std::vector<value_type>::const_iterator end_;
348
349 // Advances iter_ until it reaches a non-null item, or reaches the end.
apply_filter()350 void apply_filter() {
351 while (iter_ != end_ && *iter_ == nullptr) {
352 ++iter_;
353 }
354 }
355
356 public:
const_iterator(std::vector<value_type>::const_iterator iter,std::vector<value_type>::const_iterator end)357 const_iterator(std::vector<value_type>::const_iterator iter,
358 std::vector<value_type>::const_iterator end)
359 : iter_(iter), end_(end) {
360 apply_filter();
361 }
362
363 bool operator==(const const_iterator& other) const {
364 return iter_ == other.iter_;
365 }
366
367 bool operator!=(const const_iterator& other) const {
368 return iter_ != other.iter_;
369 }
370
371 // This is the prefix increment operator (++x), which is the operator
372 // used by C++ range iteration (for (x : y) ...). We intentionally do not
373 // provide a postfix increment operator.
374 const_iterator& operator++() {
375 ++iter_;
376 apply_filter();
377 return *this;
378 }
379
380 value_type operator*() { return *iter_; }
381 };
382
begin()383 const_iterator begin() {
384 return const_iterator(edges_.begin(), edges_.end());
385 }
end()386 const_iterator end() { return const_iterator(edges_.end(), edges_.end()); }
387 };
388
389 // Thread compatible but not thread safe.
390 class Graph {
391 public:
392 // Constructs a graph with a single SOURCE (always id kSourceId) and a
393 // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK.
394 //
395 // The graph can hold ops found in registry. `registry`s lifetime must be at
396 // least that of the constructed graph's.
397 explicit Graph(const OpRegistryInterface* registry);
398
399 // Constructs a graph with a single SOURCE (always id kSourceId) and a
400 // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK.
401 //
402 // The graph can hold ops found in `flib_def`. Unlike the constructor taking
403 // an OpRegistryInterface, this constructor copies the function definitions in
404 // `flib_def` so its lifetime may be shorter than that of the graph's. The
405 // OpRegistryInterface backing `flib_def` must still have the lifetime of the
406 // graph though.
407 explicit Graph(const FunctionLibraryDefinition& flib_def);
408
409 ~Graph();
410
411 static const int kControlSlot;
412
413 // The GraphDef version range of this graph (see graph.proto).
414 const VersionDef& versions() const;
415 void set_versions(const VersionDef& versions);
416
417 // Adds a new node to this graph, and returns it. Infers the Op and
418 // input/output types for the node. *this owns the returned instance.
419 // Returns nullptr and sets *status on error.
420 Node* AddNode(const NodeDef& node_def, Status* status);
421
422 // Copies *node, which may belong to another graph, to a new node,
423 // which is returned. Does not copy any edges. *this owns the
424 // returned instance.
425 Node* CopyNode(Node* node);
426
427 // Removes a node from this graph, including all edges from or to it.
428 // *node should not be accessed after calling this function.
429 // REQUIRES: node->IsOp()
430 void RemoveNode(Node* node);
431
432 // Adds an edge that connects the xth output of `source` to the yth input of
433 // `dest` and returns it. Does not update dest's NodeDef.
434 const Edge* AddEdge(Node* source, int x, Node* dest, int y);
435
436 // Adds a control edge (no data flows along this edge) that connects `source`
437 // to `dest`. If `dest`s NodeDef is missing the corresponding control input,
438 // adds the control input.
439 //
440 // If such a control edge already exists and `allow_duplicates` is false, no
441 // edge is added and the function returns nullptr. Otherwise the edge is
442 // unconditionally created and returned. The NodeDef is not updated if
443 // `allow_duplicates` is true.
444 // TODO(skyewm): // TODO(skyewm): allow_duplicates is needed only by
445 // graph_partition.cc. Figure out if we can do away with it.
446 const Edge* AddControlEdge(Node* source, Node* dest,
447 bool allow_duplicates = false);
448
449 // Removes edge from the graph. Does not update the destination node's
450 // NodeDef.
451 // REQUIRES: The edge must exist.
452 void RemoveEdge(const Edge* edge);
453
454 // Removes control edge `edge` from the graph. Note that this also updates
455 // the corresponding NodeDef to reflect the change.
456 // REQUIRES: The control edge must exist.
457 void RemoveControlEdge(const Edge* e);
458 // Updates the input to a node. The existing edge to `dst` is removed and an
459 // edge from `new_src` to `dst` is created. The NodeDef associated with `dst`
460 // is also updated.
461 Status UpdateEdge(Node* new_src, int new_src_index, Node* dst, int dst_index);
462
463 // Adds the function and gradient definitions in `fdef_lib` to this graph's op
464 // registry. Ignores duplicate functions, and returns a bad status if an
465 // imported function differs from an existing function or op with the same
466 // name.
467 Status AddFunctionLibrary(const FunctionDefLibrary& fdef_lib);
468
469 // The number of live nodes in the graph.
470 //
471 // Because nodes can be removed from the graph, num_nodes() is often
472 // smaller than num_node_ids(). If one needs to create an array of
473 // nodes indexed by node ids, num_node_ids() should be used as the
474 // array's size.
num_nodes()475 int num_nodes() const { return num_nodes_; }
476
477 // The number of live nodes in the graph, excluding the Source and Sink nodes.
num_op_nodes()478 int num_op_nodes() const {
479 DCHECK_GE(num_nodes_, 2);
480 return num_nodes_ - 2;
481 }
482
483 // The number of live edges in the graph.
484 //
485 // Because edges can be removed from the graph, num_edges() is often
486 // smaller than num_edge_ids(). If one needs to create an array of
487 // edges indexed by edge ids, num_edge_ids() should be used as the
488 // array's size.
num_edges()489 int num_edges() const { return num_edges_; }
490
491 // Serialize the nodes starting at `from_node_id` to a GraphDef.
492 void ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const;
493
494 // Serialize to a GraphDef.
495 void ToGraphDef(GraphDef* graph_def) const;
496
497 // This version can be called from debugger to inspect the graph content.
498 // Use the previous version outside debug context for efficiency reasons.
499 //
500 // Note: We do not expose a DebugString() API, since GraphDef.DebugString() is
501 // not defined in some TensorFlow builds.
502 GraphDef ToGraphDefDebug() const;
503
504 // Generate new node name with the specified prefix that is unique
505 // across this graph.
506 string NewName(StringPiece prefix);
507
508 // Access to the list of all nodes. Example usage:
509 // for (Node* node : graph.nodes()) { ... }
510 gtl::iterator_range<NodeIter> nodes() const;
511
512 // Access to the list of all nodes, excluding the Source and Sink nodes.
513 gtl::iterator_range<NodeIter> op_nodes() const;
514
515 // Returns one more than the maximum id assigned to any node.
num_node_ids()516 int num_node_ids() const { return nodes_.size(); }
517
518 // Returns the node associated with an id, or nullptr if no node
519 // with that id (the node with that id was removed and the id has
520 // not yet been re-used). *this owns the returned instance.
521 // REQUIRES: 0 <= id < num_node_ids().
FindNodeId(int id)522 Node* FindNodeId(int id) const { return nodes_[id]; }
523
524 // Returns one more than the maximum id assigned to any edge.
num_edge_ids()525 int num_edge_ids() const { return edges_.size(); }
526
527 // Returns the Edge associated with an id, or nullptr if no edge
528 // with that id (the node with that id was removed and the id has
529 // not yet been re-used). *this owns the returned instance.
530 // REQUIRES: 0 <= id < num_node_ids().
FindEdgeId(int id)531 const Edge* FindEdgeId(int id) const { return edges_[id]; }
532
533 // Access to the set of all edges. Example usage:
534 // for (const Edge* e : graph.edges()) { ... }
edges()535 GraphEdgesIterable edges() const { return GraphEdgesIterable(edges_); }
536
537 // The pre-defined nodes.
538 enum { kSourceId = 0, kSinkId = 1 };
source_node()539 Node* source_node() const { return FindNodeId(kSourceId); }
sink_node()540 Node* sink_node() const { return FindNodeId(kSinkId); }
541
op_registry()542 const OpRegistryInterface* op_registry() const { return &ops_; }
flib_def()543 const FunctionLibraryDefinition& flib_def() const { return ops_; }
544
CheckDeviceNameIndex(int index)545 void CheckDeviceNameIndex(int index) {
546 DCHECK_GE(index, 0);
547 DCHECK_LT(index, static_cast<int>(device_names_.size()));
548 }
549
550 int InternDeviceName(const string& device_name);
551
get_assigned_device_name(const Node & node)552 const string& get_assigned_device_name(const Node& node) const {
553 return device_names_[node.assigned_device_name_index()];
554 }
555
set_assigned_device_name_index(Node * node,int device_name_index)556 void set_assigned_device_name_index(Node* node, int device_name_index) {
557 CheckDeviceNameIndex(device_name_index);
558 node->assigned_device_name_index_ = device_name_index;
559 }
560
set_assigned_device_name(Node * node,const string & device_name)561 void set_assigned_device_name(Node* node, const string& device_name) {
562 node->assigned_device_name_index_ = InternDeviceName(device_name);
563 }
564
565 // Returns OK if `node` is non-null and belongs to this graph
566 Status IsValidNode(const Node* node) const;
567
568 // Returns OK if IsValidNode(`node`) and `idx` is less than
569 // node->num_outputs()
570 Status IsValidOutputTensor(const Node* node, int idx) const;
571
572 // Returns OK if IsValidNode(`node`) and `idx` is less than
573 // node->num_inputs()
574 Status IsValidInputTensor(const Node* node, int idx) const;
575
576 // Create and return a new WhileContext owned by this graph. This is called
577 // when a new while loop is created. `frame_name` must be unique among
578 // WhileContexts in this graph.
579 Status AddWhileContext(StringPiece frame_name, std::vector<Node*> enter_nodes,
580 std::vector<Node*> exit_nodes,
581 OutputTensor cond_output,
582 std::vector<OutputTensor> body_inputs,
583 std::vector<OutputTensor> body_outputs,
584 WhileContext** result);
585
586 // TODO(josh11b): uint64 hash() const;
587
588 private:
589 // If cost_node is non-null, then cost accounting (in CostModel)
590 // will be associated with that node rather than the new one being
591 // created.
592 //
593 // Ownership of the returned Node is not transferred to caller.
594 Node* AllocateNode(std::shared_ptr<NodeProperties> props,
595 const Node* cost_node);
596 void ReleaseNode(Node* node);
597
598 // Registry of all known ops, including functions.
599 FunctionLibraryDefinition ops_;
600
601 // GraphDef versions
602 const std::unique_ptr<VersionDef> versions_;
603
604 // Allocator which will give us good locality.
605 core::Arena arena_;
606
607 // Map from node ids to allocated nodes. nodes_[id] may be nullptr if
608 // the node with that id was removed from the graph.
609 std::vector<Node*> nodes_;
610
611 // Number of nodes alive.
612 int64 num_nodes_ = 0;
613
614 // Map from edge ids to allocated edges. edges_[id] may be nullptr if
615 // the edge with that id was removed from the graph.
616 std::vector<Edge*> edges_;
617
618 // The number of entries in edges_ that are not nullptr.
619 int num_edges_ = 0;
620
621 // Allocated but free nodes and edges.
622 std::vector<Node*> free_nodes_;
623 std::vector<Edge*> free_edges_;
624
625 // For generating unique names.
626 int name_counter_ = 0;
627
628 // In most graphs, the number of unique values used for the
629 // Node::assigned_device_name() property is quite small. If the graph is
630 // large, then this duplication of values can consume a significant amount of
631 // memory. Instead, we represent the same information using an interning
632 // table, which consists of a vector of unique strings (device_names_), as
633 // well a map (device_names_map_) from unique strings to indices within the
634 // unique string table.
635 //
636 // The InternDeviceName() method handles adding a new entry into the table,
637 // or locating the index of an existing entry.
638 //
639 // The fact that Node::assigned_device_name() is implemented using an
640 // interning table is intentionally public. This allows algorithms that
641 // frequently access this field to do so efficiently, especially for the case
642 // where the assigned_device_name of one Node is copied directly from that
643 // of another Node.
644
645 // A table of the unique assigned device names. Indices do NOT correspond
646 // to node IDs. Index 0 is always the empty string.
647 std::vector<string> device_names_;
648
649 // Maps unique device names to indices within device_names_[i].
650 std::unordered_map<string, int> device_names_map_;
651
652 // All the while contexts owned by this graph, keyed by frame name,
653 // corresponding to all the while loops contained in this graph (including
654 // nested loops). The stored contexts are usually accessed via
655 // AddWhileContext() or Node::while_ctx(), but this manages the lifetime.
656 std::map<string, WhileContext> while_ctxs_;
657
658 // Searches through edges_ for the Edge whose destination node and index
659 // matches dst. An edge with destination `dst` must exist in the graph.
660 const Edge* FindEdge(const Node* dst, int index);
661
662 TF_DISALLOW_COPY_AND_ASSIGN(Graph);
663 };
664
665 // TODO(josh11b): We may want to support keeping an index on various
666 // node/edge attributes in a graph, particularly node names.
667
668 // Helper routines
669
IsSource(const Node * node)670 inline bool IsSource(const Node* node) { return node->IsSource(); }
IsSink(const Node * node)671 inline bool IsSink(const Node* node) { return node->IsSink(); }
IsSwitch(const Node * node)672 inline bool IsSwitch(const Node* node) { return node->IsSwitch(); }
IsMerge(const Node * node)673 inline bool IsMerge(const Node* node) { return node->IsMerge(); }
IsEnter(const Node * node)674 inline bool IsEnter(const Node* node) { return node->IsEnter(); }
IsExit(const Node * node)675 inline bool IsExit(const Node* node) { return node->IsExit(); }
IsNextIteration(const Node * n)676 inline bool IsNextIteration(const Node* n) { return n->IsNextIteration(); }
IsLoopCond(const Node * node)677 inline bool IsLoopCond(const Node* node) { return node->IsLoopCond(); }
IsControlTrigger(const Node * n)678 inline bool IsControlTrigger(const Node* n) { return n->IsControlTrigger(); }
IsSend(const Node * node)679 inline bool IsSend(const Node* node) { return node->IsSend(); }
IsRecv(const Node * node)680 inline bool IsRecv(const Node* node) { return node->IsRecv(); }
IsHostSend(const Node * node)681 inline bool IsHostSend(const Node* node) { return node->IsHostSend(); }
IsHostRecv(const Node * node)682 inline bool IsHostRecv(const Node* node) { return node->IsHostRecv(); }
683
684 // True for Nodes that mediate the transfer of values between processes.
IsTransferNode(const Node * n)685 inline bool IsTransferNode(const Node* n) { return IsSend(n) || IsRecv(n); }
686
IsConstant(const Node * node)687 inline bool IsConstant(const Node* node) { return node->IsConstant(); }
IsVariable(const Node * node)688 inline bool IsVariable(const Node* node) { return node->IsVariable(); }
IsIdentity(const Node * node)689 inline bool IsIdentity(const Node* node) { return node->IsIdentity(); }
690
691 // Returns true iff 'n' is a control flow node.
IsControlFlow(const Node * n)692 inline bool IsControlFlow(const Node* n) { return n->IsControlFlow(); }
693
694 // Returns true if the node only depends on its input's metadata
695 // (shape). Specifically, returns true for "Size", "Shape" and "Rank" ops.
IsMetadata(const Node * n)696 inline bool IsMetadata(const Node* n) { return n->IsMetadata(); }
697
IsHostMemoryPreserving(const Node * node)698 inline bool IsHostMemoryPreserving(const Node* node) {
699 return IsIdentity(node) || IsControlFlow(node);
700 }
701
702 // Iterator for stepping through the nodes of a graph.
703 class NodeIter {
704 public:
705 NodeIter(const Graph* graph, int id);
706 bool operator==(const NodeIter& rhs);
707 bool operator!=(const NodeIter& rhs);
708 void operator++();
709 Node* operator*();
710 Node* operator->();
711
712 private:
713 // Invariant: id_ == graph_->num_node_ids() || graph_->FindId(id_) != nullptr
714 const Graph* graph_;
715 int id_;
716 };
717
718 // Iterator for stepping through the neighbors of a node.
719 class NeighborIter {
720 public:
721 NeighborIter(EdgeSet::const_iterator iter, bool incoming);
722 bool operator==(const NeighborIter& rhs);
723 bool operator!=(const NeighborIter& rhs);
724 void operator++();
725 Node* operator*();
726 Node* operator->();
727
728 private:
729 EdgeSet::const_iterator iter_;
730 bool incoming_;
731 };
732
733 // IMPLEMENTATION DETAILS, PLEASE IGNORE
734
NodeIter(const Graph * graph,int id)735 inline NodeIter::NodeIter(const Graph* graph, int id)
736 : graph_(graph), id_(id) {}
737
738 inline bool NodeIter::operator==(const NodeIter& rhs) {
739 DCHECK(graph_ == rhs.graph_);
740 return id_ == rhs.id_;
741 }
742
743 inline bool NodeIter::operator!=(const NodeIter& rhs) {
744 return !(*this == rhs);
745 }
746
747 inline void NodeIter::operator++() {
748 while (1) {
749 DCHECK_LE(id_, graph_->num_node_ids());
750 ++id_;
751 if (id_ >= graph_->num_node_ids() || graph_->FindNodeId(id_) != nullptr) {
752 return;
753 }
754 }
755 }
756
757 inline Node* NodeIter::operator*() { return graph_->FindNodeId(id_); }
758
759 inline Node* NodeIter::operator->() { return graph_->FindNodeId(id_); }
760
NeighborIter(EdgeSet::const_iterator iter,bool incoming)761 inline NeighborIter::NeighborIter(EdgeSet::const_iterator iter, bool incoming)
762 : iter_(iter), incoming_(incoming) {}
763
764 inline bool NeighborIter::operator==(const NeighborIter& rhs) {
765 return iter_ == rhs.iter_ && incoming_ == rhs.incoming_;
766 }
767
768 inline bool NeighborIter::operator!=(const NeighborIter& rhs) {
769 return !(*this == rhs);
770 }
771
772 inline void NeighborIter::operator++() { ++iter_; }
773
774 inline Node* NeighborIter::operator*() {
775 const Edge* e = *iter_;
776 return incoming_ ? e->src() : e->dst();
777 }
778
779 inline Node* NeighborIter::operator->() {
780 const Edge* e = *iter_;
781 return incoming_ ? e->src() : e->dst();
782 }
783
IsControlEdge()784 inline bool Edge::IsControlEdge() const {
785 // Note that if either src_output_ or dst_input_ is kControlSlot,
786 // so is the other one (AddEdge checks this).
787 return src_output_ == Graph::kControlSlot;
788 }
789
nodes()790 inline gtl::iterator_range<NodeIter> Graph::nodes() const {
791 // Note that NodeId 0 is always valid since we don't let the source
792 // node be removed from the graph.
793 return gtl::make_range(NodeIter(this, 0), NodeIter(this, num_node_ids()));
794 }
795
op_nodes()796 inline gtl::iterator_range<NodeIter> Graph::op_nodes() const {
797 // Note that NodeId 0 is always valid since we don't let the source
798 // node be removed from the graph.
799 //
800 // The current implementation of Graph maintains the invariant that the
801 // first two nodes are the source and sink nodes, and all other nodes are op
802 // nodes. This method (op_nodes()) relies on this invariant.
803 NodeIter begin(this, 0);
804 NodeIter end(this, num_node_ids());
805 if (begin != end) {
806 ++begin;
807 }
808 if (begin != end) {
809 ++begin;
810 }
811 return gtl::make_range(begin, end);
812 }
813
set_assigned_device_name_index(int index)814 inline void Node::set_assigned_device_name_index(int index) {
815 graph_->CheckDeviceNameIndex(index);
816 assigned_device_name_index_ = index;
817 }
818
set_assigned_device_name(const string & device_name)819 inline void Node::set_assigned_device_name(const string& device_name) {
820 graph_->set_assigned_device_name(this, device_name);
821 }
822
assigned_device_name()823 inline const string& Node::assigned_device_name() const {
824 return graph_->get_assigned_device_name(*this);
825 }
826
827 } // namespace tensorflow
828
829 #endif // TENSORFLOW_GRAPH_GRAPH_H_
830