/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
// A Graph describes a set of computations that are to be
// performed, as well as the dependencies between those
// computations. The basic model is a DAG (directed acyclic graph) with
// * internal nodes representing computational operations to be performed;
// * edges represent dependencies, indicating the target may only be
//   executed once the source has completed; and
// * predefined "source" (start) and "sink" (finish) nodes -- the source
//   should be the only node that doesn't depend on anything, and the sink
//   should be the only node that nothing depends on.
//
// Note: Node ids are intended to be relatively dense in the
// 0..max_id range, but there may be gaps since ids won't be reused.
//
// Note: Some dependencies between operations are due to one operation
// consuming the output of another. In fact operations can produce
// multiple outputs and consume multiple inputs, and some
// optimizations will care about which specific outputs are connected
// to which specific inputs. We therefore represent data dependency
// between output O of layer A and input I of layer B using
// "input index" and "output index" labels per edge.
36
#ifndef TENSORFLOW_CORE_GRAPH_GRAPH_H_
#define TENSORFLOW_CORE_GRAPH_GRAPH_H_

#include <functional>
#include <string>
#include <vector>

#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/edgeset.h"
#include "tensorflow/core/lib/core/arena.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

class Edge;
class EdgeSetTest;
class Graph;
class GraphDef;
class Node;
struct OutputTensor;
class VersionDef;
class WhileContext;

class NeighborIter;     // Declared below
class NodeIter;         // Declared below
struct NodeProperties;  // Defined in .cc
69
70 class Node {
71 public:
72 string DebugString() const;
id()73 int id() const { return id_; }
cost_id()74 int cost_id() const { return cost_id_; }
75 const string& name() const;
76 void set_name(string name);
77 const string& type_string() const;
78
79 // def() provides the NodeDef the user supplied, but the specifics
80 // of this Node may have changed due to placement, optimization, etc.
81 // In particular:
82 // * def().name() will match name();
83 // * def().op() will match type_string() and op_def().name();
84 // * def().input() is not reliable, use "in_edges()" below instead;
85 // * def().device() is the "user's requested device" and may not match
86 // the actual assigned device, see assigned_device_name() below;
87 // * def().attr() is authoritative.
88 // TODO(irving): Replace with NodeInfo.
89 const NodeDef& def() const;
90 const OpDef& op_def() const;
91
92 // input and output types
93 int32 num_inputs() const;
94 DataType input_type(int32 i) const;
95 const DataTypeVector& input_types() const;
96
97 int32 num_outputs() const;
98 DataType output_type(int32 o) const;
99 const DataTypeVector& output_types() const;
100
101 // The device requested by the user. For the actual assigned device,
102 // use assigned_device_name() below.
103 const string& requested_device() const;
104
105 // This changes the user requested device but not necessarily the device that
106 // on which the operation will run.
107 void set_requested_device(const string& device);
108
109 // This gives the device the runtime has assigned this node to. If
110 // you want the device the user requested, use def().device() instead.
111 // TODO(josh11b): Validate that the assigned_device, if not empty:
112 // fully specifies a device, and satisfies def().device().
113 // TODO(josh11b): Move assigned_device_name outside of Node into a
114 // NodeId->DeviceName map.
115 const string& assigned_device_name() const;
116 void set_assigned_device_name(const string& device_name);
has_assigned_device_name()117 bool has_assigned_device_name() const {
118 return assigned_device_name_index_ > 0;
119 }
assigned_device_name_index()120 int assigned_device_name_index() const { return assigned_device_name_index_; }
121 void set_assigned_device_name_index(int index);
122
123 // Sets 'original_node_names' field of this node's DebugInfo proto to
124 // 'names'.
125 void set_original_node_names(const std::vector<string>& names);
126
127 // Read only access to attributes
128 AttrSlice attrs() const;
129
130 // Inputs requested by the NodeDef. For the actual inputs, use in_edges.
131 const protobuf::RepeatedPtrField<string>& requested_inputs() const;
132
133 // Get the neighboring nodes via edges either in or out of this node. This
134 // includes control edges.
135 gtl::iterator_range<NeighborIter> in_nodes() const;
136 gtl::iterator_range<NeighborIter> out_nodes() const;
in_edges()137 const EdgeSet& in_edges() const { return in_edges_; }
out_edges()138 const EdgeSet& out_edges() const { return out_edges_; }
139
140 // Node type helpers.
IsSource()141 bool IsSource() const { return id() == 0; }
IsSink()142 bool IsSink() const { return id() == 1; }
143 // Anything other than the special Source & Sink nodes.
IsOp()144 bool IsOp() const { return id() > 1; }
145
146 // Node class helpers
IsSwitch()147 bool IsSwitch() const { return class_ == NC_SWITCH; }
IsMerge()148 bool IsMerge() const { return class_ == NC_MERGE; }
IsEnter()149 bool IsEnter() const { return class_ == NC_ENTER; }
IsExit()150 bool IsExit() const { return class_ == NC_EXIT; }
IsNextIteration()151 bool IsNextIteration() const { return class_ == NC_NEXT_ITERATION; }
IsLoopCond()152 bool IsLoopCond() const { return class_ == NC_LOOP_COND; }
IsControlTrigger()153 bool IsControlTrigger() const { return class_ == NC_CONTROL_TRIGGER; }
IsSend()154 bool IsSend() const { return class_ == NC_SEND || class_ == NC_HOST_SEND; }
IsRecv()155 bool IsRecv() const { return class_ == NC_RECV || class_ == NC_HOST_RECV; }
IsConstant()156 bool IsConstant() const { return class_ == NC_CONSTANT; }
IsVariable()157 bool IsVariable() const { return class_ == NC_VARIABLE; }
IsIdentity()158 bool IsIdentity() const { return class_ == NC_IDENTITY; }
IsGetSessionHandle()159 bool IsGetSessionHandle() const { return class_ == NC_GET_SESSION_HANDLE; }
IsGetSessionTensor()160 bool IsGetSessionTensor() const { return class_ == NC_GET_SESSION_TENSOR; }
IsDeleteSessionTensor()161 bool IsDeleteSessionTensor() const {
162 return class_ == NC_DELETE_SESSION_TENSOR;
163 }
IsControlFlow()164 bool IsControlFlow() const {
165 return (class_ != NC_OTHER) && // Fast path
166 (IsSwitch() || IsMerge() || IsEnter() || IsExit() ||
167 IsNextIteration());
168 }
IsHostSend()169 bool IsHostSend() const { return class_ == NC_HOST_SEND; }
IsHostRecv()170 bool IsHostRecv() const { return class_ == NC_HOST_RECV; }
IsScopedAllocator()171 bool IsScopedAllocator() const { return class_ == NC_SCOPED_ALLOCATOR; }
IsCollective()172 bool IsCollective() const { return class_ == NC_COLLECTIVE; }
173
IsMetadata()174 bool IsMetadata() const { return class_ == NC_METADATA; }
IsFakeParam()175 bool IsFakeParam() const { return class_ == NC_FAKE_PARAM; }
IsPartitionedCall()176 bool IsPartitionedCall() const { return class_ == NC_PARTITIONED_CALL; }
177 // Is this node a function input
IsArg()178 bool IsArg() const { return class_ == NC_ARG; }
179 // Is this node a function output
IsRetval()180 bool IsRetval() const { return class_ == NC_RETVAL; }
181
182 template <typename T>
AddAttr(const string & name,const T & val)183 void AddAttr(const string& name, const T& val) {
184 SetAttrValue(val, AddAttrHelper(name));
185 UpdateProperties();
186 }
187
188 void ClearAttr(const string& name);
189
190 // Returns into '*e' the edge connecting to the 'idx' input of this Node.
191 Status input_edge(int idx, const Edge** e) const;
192
193 // Returns into '*edges' the input data edges of this Node, indexed by input
194 // number. Does not return control edges.
195 Status input_edges(std::vector<const Edge*>* edges) const;
196
197 // Returns into '*n' the node that has an output connected to the
198 // 'idx' input of this Node.
199 Status input_node(int idx, const Node** n) const;
200 Status input_node(int idx, Node** n) const;
201
202 // Returns into '*t' the idx-th input tensor of this node, represented as the
203 // output tensor of input_node(idx).
204 Status input_tensor(int idx, OutputTensor* t) const;
205
while_ctx()206 WhileContext* while_ctx() const { return while_ctx_; }
set_while_ctx(WhileContext * while_ctx)207 void set_while_ctx(WhileContext* while_ctx) {
208 DCHECK(IsExit());
209 DCHECK(while_ctx_ == nullptr);
210 while_ctx_ = while_ctx;
211 }
212
213 private:
214 friend class Graph;
215 Node();
216
properties()217 NodeProperties* properties() const { return props_.get(); }
218
219 void Initialize(int id, int cost_id, std::shared_ptr<NodeProperties> props);
220
221 // Releases memory from props_, in addition to restoring *this to its
222 // uninitialized state.
223 void Clear();
224
225 // Make a copy of the Node's props_ if props_ is shared with
226 // other nodes. This must be called before mutating properties,
227 // e.g. in AddAttr.
228 void MaybeCopyOnWrite();
229
230 // Called after an attr has changed. Decides whether we need to update some
231 // property of the node (stored in props_).
232 void UpdateProperties();
233
234 AttrValue* AddAttrHelper(const string& name);
235
236 // A set of mutually exclusive classes for different kinds of nodes,
237 // class_ is initialized in the Node::Initialize routine based on the
238 // node's type_string().
239 enum NodeClass {
240 NC_UNINITIALIZED,
241 NC_SWITCH,
242 NC_MERGE,
243 NC_ENTER,
244 NC_EXIT,
245 NC_NEXT_ITERATION,
246 NC_LOOP_COND,
247 NC_CONTROL_TRIGGER,
248 NC_SEND,
249 NC_HOST_SEND,
250 NC_RECV,
251 NC_HOST_RECV,
252 NC_CONSTANT,
253 NC_VARIABLE,
254 NC_IDENTITY,
255 NC_GET_SESSION_HANDLE,
256 NC_GET_SESSION_TENSOR,
257 NC_DELETE_SESSION_TENSOR,
258 NC_METADATA,
259 NC_SCOPED_ALLOCATOR,
260 NC_COLLECTIVE,
261 NC_FAKE_PARAM,
262 NC_PARTITIONED_CALL,
263 NC_ARG,
264 NC_RETVAL,
265 NC_OTHER // Not a special kind of node
266 };
267
268 static const std::unordered_map<string, NodeClass>& kNodeClassTable;
269
270 static NodeClass GetNodeClassForOp(const string& ts);
271
272 int id_; // -1 until Initialize() is called
273 int cost_id_; // -1 if there is no corresponding cost accounting node
274 NodeClass class_;
275
276 EdgeSet in_edges_;
277 EdgeSet out_edges_;
278
279 // NOTE(skyewm): inheriting from core::RefCounted may have a slight
280 // performance benefit over using shared_ptr, at the cost of manual ref
281 // counting
282 std::shared_ptr<NodeProperties> props_;
283
284 // Index within Graph::device_names_ of the name of device assigned
285 // to perform this computation.
286 int assigned_device_name_index_;
287
288 // A back-pointer to the Graph that owns this node. Currently, this exists
289 // solely to allow Node::[set_]assigned_device_name() to work. However, if all
290 // callers of Node::[set_]assigned_device_name() are modified to use the
291 // equivalent methods defined directly on Graph, then we can remove this
292 // field and reclaim that memory.
293 Graph* graph_;
294
295 // Set if this is an exit node of a while loop with an associated
296 // WhileContext. Otherwise null. (This is only set for exit nodes because
297 // they're the first nodes of a loop encountered while creating the gradient
298 // graph. Exit nodes that are part of while loop gradient graphs will not have
299 // this set.)
300 WhileContext* while_ctx_;
301
302 TF_DISALLOW_COPY_AND_ASSIGN(Node);
303 };
304
305 // Stores debug information associated with the Node.
306 struct NodeDebugInfo {
307 const string name;
308 std::vector<string> original_node_names;
309
310 NodeDebugInfo(const Node& n);
311 NodeDebugInfo(const NodeDef& ndef);
312 };
313
314 // Represents an input of a node, i.e., the `index`-th input to `node`.
315 struct InputTensor {
316 Node* node;
317 int index;
318
InputTensorInputTensor319 InputTensor(Node* n, int i) : node(n), index(i) {}
InputTensorInputTensor320 InputTensor() : node(nullptr), index(0) {}
321
322 // Returns true if this InputTensor is identical to 'other'. Nodes are
323 // compared using pointer equality.
324 bool operator==(const InputTensor& other) const;
325
326 // A hash function for InputTensors. Nodes are hashed based on their pointer
327 // value.
328 struct Hash {
329 uint64 operator()(InputTensor const& s) const;
330 };
331 };
332
333 // Represents an output of a node, i.e., the `index`-th output of `node`. Note
334 // that a single `OutputTensor` can correspond to multiple `Edge`s if the output
335 // is consumed by multiple destination nodes.
336 struct OutputTensor {
337 Node* node;
338 int index;
339
OutputTensorOutputTensor340 OutputTensor(Node* n, int i) : node(n), index(i) {}
OutputTensorOutputTensor341 OutputTensor() : node(nullptr), index(0) {}
342
343 // Returns true if this OutputTensor is identical to 'other'. Nodes are
344 // compared using pointer equality.
345 bool operator==(const OutputTensor& other) const;
346
347 // A hash function for OutputTensors. Nodes are hashed based on their pointer
348 // value.
349 struct Hash {
350 uint64 operator()(OutputTensor const& s) const;
351 };
352 };
353
354 class Edge {
355 public:
src()356 Node* src() const { return src_; }
dst()357 Node* dst() const { return dst_; }
id()358 int id() const { return id_; }
359
360 // Return the index of the source output that produces the data
361 // carried by this edge. The special value kControlSlot is used
362 // for control dependencies.
src_output()363 int src_output() const { return src_output_; }
364
365 // Return the index of the destination input that consumes the data
366 // carried by this edge. The special value kControlSlot is used
367 // for control dependencies.
dst_input()368 int dst_input() const { return dst_input_; }
369
370 // Return true iff this is an edge that indicates a control-flow
371 // (as opposed to a data-flow) dependency.
372 bool IsControlEdge() const;
373
374 string DebugString() const;
375
376 private:
Edge()377 Edge() {}
378
379 friend class EdgeSetTest;
380 friend class Graph;
381 Node* src_;
382 Node* dst_;
383 int id_;
384 int src_output_;
385 int dst_input_;
386 };
387
388 // Allows for iteration of the edges of a Graph, by iterating the underlying
389 // Graph.edges_ vector while skipping over null entries.
390 class GraphEdgesIterable {
391 private:
392 const std::vector<Edge*>& edges_;
393
394 public:
GraphEdgesIterable(const std::vector<Edge * > & edges)395 explicit GraphEdgesIterable(const std::vector<Edge*>& edges)
396 : edges_(edges) {}
397
398 typedef Edge* value_type;
399
400 class const_iterator {
401 private:
402 // The underlying iterator.
403 std::vector<value_type>::const_iterator iter_;
404
405 // The end of the underlying iterator.
406 std::vector<value_type>::const_iterator end_;
407
408 // Advances iter_ until it reaches a non-null item, or reaches the end.
apply_filter()409 void apply_filter() {
410 while (iter_ != end_ && *iter_ == nullptr) {
411 ++iter_;
412 }
413 }
414
415 public:
const_iterator(std::vector<value_type>::const_iterator iter,std::vector<value_type>::const_iterator end)416 const_iterator(std::vector<value_type>::const_iterator iter,
417 std::vector<value_type>::const_iterator end)
418 : iter_(iter), end_(end) {
419 apply_filter();
420 }
421
422 bool operator==(const const_iterator& other) const {
423 return iter_ == other.iter_;
424 }
425
426 bool operator!=(const const_iterator& other) const {
427 return iter_ != other.iter_;
428 }
429
430 // This is the prefix increment operator (++x), which is the operator
431 // used by C++ range iteration (for (x : y) ...). We intentionally do not
432 // provide a postfix increment operator.
433 const_iterator& operator++() {
434 ++iter_;
435 apply_filter();
436 return *this;
437 }
438
439 value_type operator*() { return *iter_; }
440 };
441
begin()442 const_iterator begin() {
443 return const_iterator(edges_.begin(), edges_.end());
444 }
end()445 const_iterator end() { return const_iterator(edges_.end(), edges_.end()); }
446 };
447
448 // Thread compatible but not thread safe.
449 class Graph {
450 public:
451 // Constructs a graph with a single SOURCE (always id kSourceId) and a
452 // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK.
453 //
454 // The graph can hold ops found in the registry. `ops`s lifetime must be at
455 // least that of the constructed graph's.
456 explicit Graph(const OpRegistryInterface* ops);
457
458 // Constructs a graph with a single SOURCE (always id kSourceId) and a
459 // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK.
460 //
461 // The graph can hold ops found in `flib_def`. Unlike the constructor taking
462 // an OpRegistryInterface, this constructor copies the function definitions in
463 // `flib_def` so its lifetime may be shorter than that of the graph's. The
464 // OpRegistryInterface backing `flib_def` must still have the lifetime of the
465 // graph though.
466 explicit Graph(const FunctionLibraryDefinition& flib_def);
467
468 ~Graph();
469
470 static const int kControlSlot;
471
472 // The GraphDef version range of this graph (see graph.proto).
473 const VersionDef& versions() const;
474 void set_versions(const VersionDef& versions);
475
476 // Adds a new node to this graph, and returns it. Infers the Op and
477 // input/output types for the node. *this owns the returned instance.
478 // Returns nullptr and sets *status on error.
479 Node* AddNode(const NodeDef& node_def, Status* status);
480
481 // Copies *node, which may belong to another graph, to a new node,
482 // which is returned. Does not copy any edges. *this owns the
483 // returned instance.
484 Node* CopyNode(const Node* node);
485
486 // Removes a node from this graph, including all edges from or to it.
487 // *node should not be accessed after calling this function.
488 // REQUIRES: node->IsOp()
489 void RemoveNode(Node* node);
490
491 // Adds an edge that connects the xth output of `source` to the yth input of
492 // `dest` and returns it. Does not update dest's NodeDef.
493 const Edge* AddEdge(Node* source, int x, Node* dest, int y);
494
495 // Adds a control edge (no data flows along this edge) that connects `source`
496 // to `dest`. If `dest`s NodeDef is missing the corresponding control input,
497 // adds the control input.
498 //
499 // If such a control edge already exists and `allow_duplicates` is false, no
500 // edge is added and the function returns nullptr. Otherwise the edge is
501 // unconditionally created and returned. The NodeDef is not updated if
502 // `allow_duplicates` is true.
503 // TODO(skyewm): // TODO(skyewm): allow_duplicates is needed only by
504 // graph_partition.cc. Figure out if we can do away with it.
505 const Edge* AddControlEdge(Node* source, Node* dest,
506 bool allow_duplicates = false);
507
508 // Removes edge from the graph. Does not update the destination node's
509 // NodeDef.
510 // REQUIRES: The edge must exist.
511 void RemoveEdge(const Edge* edge);
512
513 // Removes control edge `edge` from the graph. Note that this also updates
514 // the corresponding NodeDef to reflect the change.
515 // REQUIRES: The control edge must exist.
516 void RemoveControlEdge(const Edge* e);
517
518 // Updates the input to a node. The existing edge to `dst` is removed and an
519 // edge from `new_src` to `dst` is created. The NodeDef associated with `dst`
520 // is also updated.
521 Status UpdateEdge(Node* new_src, int new_src_index, Node* dst, int dst_index);
522
523 // Like AddEdge but updates dst's NodeDef. Used to add an input edge to a
524 // "While" op during gradient construction, see AddInputWhileHack in
525 // python_api.h for more details.
526 Status AddWhileInputHack(Node* new_src, int new_src_index, Node* dst);
527
528 // Adds the function and gradient definitions in `fdef_lib` to this graph's op
529 // registry. Ignores duplicate functions, and returns a bad status if an
530 // imported function differs from an existing function or op with the same
531 // name.
532 Status AddFunctionLibrary(const FunctionDefLibrary& fdef_lib);
533
534 // The number of live nodes in the graph.
535 //
536 // Because nodes can be removed from the graph, num_nodes() is often
537 // smaller than num_node_ids(). If one needs to create an array of
538 // nodes indexed by node ids, num_node_ids() should be used as the
539 // array's size.
num_nodes()540 int num_nodes() const { return num_nodes_; }
541
542 // The number of live nodes in the graph, excluding the Source and Sink nodes.
num_op_nodes()543 int num_op_nodes() const {
544 DCHECK_GE(num_nodes_, 2);
545 return num_nodes_ - 2;
546 }
547
548 // The number of live edges in the graph.
549 //
550 // Because edges can be removed from the graph, num_edges() is often
551 // smaller than num_edge_ids(). If one needs to create an array of
552 // edges indexed by edge ids, num_edge_ids() should be used as the
553 // array's size.
num_edges()554 int num_edges() const { return num_edges_; }
555
556 // Serialize the nodes starting at `from_node_id` to a GraphDef.
557 void ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const;
558
559 // Serialize to a GraphDef.
560 void ToGraphDef(GraphDef* graph_def) const;
561
562 // This version can be called from debugger to inspect the graph content.
563 // Use the previous version outside debug context for efficiency reasons.
564 //
565 // Note: We do not expose a DebugString() API, since GraphDef.DebugString() is
566 // not defined in some TensorFlow builds.
567 GraphDef ToGraphDefDebug() const;
568
569 // Generate new node name with the specified prefix that is unique
570 // across this graph.
571 string NewName(StringPiece prefix);
572
573 // Access to the list of all nodes. Example usage:
574 // for (Node* node : graph.nodes()) { ... }
575 gtl::iterator_range<NodeIter> nodes() const;
576
577 // Access to the list of all nodes, excluding the Source and Sink nodes.
578 gtl::iterator_range<NodeIter> op_nodes() const;
579
580 // Returns one more than the maximum id assigned to any node.
num_node_ids()581 int num_node_ids() const { return nodes_.size(); }
582
583 // Returns the node associated with an id, or nullptr if no node
584 // with that id (the node with that id was removed and the id has
585 // not yet been re-used). *this owns the returned instance.
586 // REQUIRES: 0 <= id < num_node_ids().
FindNodeId(int id)587 Node* FindNodeId(int id) const { return nodes_[id]; }
588
589 // Returns one more than the maximum id assigned to any edge.
num_edge_ids()590 int num_edge_ids() const { return edges_.size(); }
591
592 // Returns the Edge associated with an id, or nullptr if no edge
593 // with that id (the node with that id was removed and the id has
594 // not yet been re-used). *this owns the returned instance.
595 // REQUIRES: 0 <= id < num_node_ids().
FindEdgeId(int id)596 const Edge* FindEdgeId(int id) const { return edges_[id]; }
597
598 // Access to the set of all edges. Example usage:
599 // for (const Edge* e : graph.edges()) { ... }
edges()600 GraphEdgesIterable edges() const { return GraphEdgesIterable(edges_); }
601
602 // The pre-defined nodes.
603 enum { kSourceId = 0, kSinkId = 1 };
source_node()604 Node* source_node() const { return FindNodeId(kSourceId); }
sink_node()605 Node* sink_node() const { return FindNodeId(kSinkId); }
606
op_registry()607 const OpRegistryInterface* op_registry() const { return &ops_; }
flib_def()608 const FunctionLibraryDefinition& flib_def() const { return ops_; }
609
CheckDeviceNameIndex(int index)610 void CheckDeviceNameIndex(int index) {
611 DCHECK_GE(index, 0);
612 DCHECK_LT(index, static_cast<int>(device_names_.size()));
613 }
614
615 int InternDeviceName(const string& device_name);
616
get_assigned_device_name(const Node & node)617 const string& get_assigned_device_name(const Node& node) const {
618 return device_names_[node.assigned_device_name_index()];
619 }
620
set_assigned_device_name_index(Node * node,int device_name_index)621 void set_assigned_device_name_index(Node* node, int device_name_index) {
622 CheckDeviceNameIndex(device_name_index);
623 node->assigned_device_name_index_ = device_name_index;
624 }
625
set_assigned_device_name(Node * node,const string & device_name)626 void set_assigned_device_name(Node* node, const string& device_name) {
627 node->assigned_device_name_index_ = InternDeviceName(device_name);
628 }
629
630 // Returns OK if `node` is non-null and belongs to this graph
631 Status IsValidNode(const Node* node) const;
632
633 // Returns OK if IsValidNode(`node`) and `idx` is a valid output. Does not
634 // accept control outputs.
635 Status IsValidOutputTensor(const Node* node, int idx) const;
636
637 // Returns OK if IsValidNode(`node`) and `idx` a valid input. Does not accept
638 // control inputs.
639 Status IsValidInputTensor(const Node* node, int idx) const;
640
641 // Create and return a new WhileContext owned by this graph. This is called
642 // when a new while loop is created. `frame_name` must be unique among
643 // WhileContexts in this graph.
644 Status AddWhileContext(StringPiece frame_name, std::vector<Node*> enter_nodes,
645 std::vector<Node*> exit_nodes,
646 OutputTensor cond_output,
647 std::vector<OutputTensor> body_inputs,
648 std::vector<OutputTensor> body_outputs,
649 WhileContext** result);
650
651 // Builds a node name to node pointer index for all nodes in the graph.
652 std::unordered_map<string, Node*> BuildNodeNameIndex() const;
653
654 // TODO(josh11b): uint64 hash() const;
655
656 private:
657 // If cost_node is non-null, then cost accounting (in CostModel)
658 // will be associated with that node rather than the new one being
659 // created.
660 //
661 // Ownership of the returned Node is not transferred to caller.
662 Node* AllocateNode(std::shared_ptr<NodeProperties> props,
663 const Node* cost_node);
664 void ReleaseNode(Node* node);
665
666 // Registry of all known ops, including functions.
667 FunctionLibraryDefinition ops_;
668
669 // GraphDef versions
670 const std::unique_ptr<VersionDef> versions_;
671
672 // Allocator which will give us good locality.
673 core::Arena arena_;
674
675 // Map from node ids to allocated nodes. nodes_[id] may be nullptr if
676 // the node with that id was removed from the graph.
677 std::vector<Node*> nodes_;
678
679 // Number of nodes alive.
680 int64 num_nodes_ = 0;
681
682 // Map from edge ids to allocated edges. edges_[id] may be nullptr if
683 // the edge with that id was removed from the graph.
684 std::vector<Edge*> edges_;
685
686 // The number of entries in edges_ that are not nullptr.
687 int num_edges_ = 0;
688
689 // Allocated but free nodes and edges.
690 std::vector<Node*> free_nodes_;
691 std::vector<Edge*> free_edges_;
692
693 // For generating unique names.
694 int name_counter_ = 0;
695
696 // In most graphs, the number of unique values used for the
697 // Node::assigned_device_name() property is quite small. If the graph is
698 // large, then this duplication of values can consume a significant amount of
699 // memory. Instead, we represent the same information using an interning
700 // table, which consists of a vector of unique strings (device_names_), as
701 // well a map (device_names_map_) from unique strings to indices within the
702 // unique string table.
703 //
704 // The InternDeviceName() method handles adding a new entry into the table,
705 // or locating the index of an existing entry.
706 //
707 // The fact that Node::assigned_device_name() is implemented using an
708 // interning table is intentionally public. This allows algorithms that
709 // frequently access this field to do so efficiently, especially for the case
710 // where the assigned_device_name of one Node is copied directly from that
711 // of another Node.
712
713 // A table of the unique assigned device names. Indices do NOT correspond
714 // to node IDs. Index 0 is always the empty string.
715 std::vector<string> device_names_;
716
717 // Maps unique device names to indices within device_names_[i].
718 std::unordered_map<string, int> device_names_map_;
719
720 // All the while contexts owned by this graph, keyed by frame name,
721 // corresponding to all the while loops contained in this graph (including
722 // nested loops). The stored contexts are usually accessed via
723 // AddWhileContext() or Node::while_ctx(), but this manages the lifetime.
724 std::map<string, WhileContext> while_ctxs_;
725
726 TF_DISALLOW_COPY_AND_ASSIGN(Graph);
727 };
728
729 // TODO(josh11b): We may want to support keeping an index on various
730 // node/edge attributes in a graph, particularly node names.
731
732 // Helper routines
733
IsSource(const Node * node)734 inline bool IsSource(const Node* node) { return node->IsSource(); }
IsSink(const Node * node)735 inline bool IsSink(const Node* node) { return node->IsSink(); }
IsSwitch(const Node * node)736 inline bool IsSwitch(const Node* node) { return node->IsSwitch(); }
IsMerge(const Node * node)737 inline bool IsMerge(const Node* node) { return node->IsMerge(); }
IsEnter(const Node * node)738 inline bool IsEnter(const Node* node) { return node->IsEnter(); }
IsExit(const Node * node)739 inline bool IsExit(const Node* node) { return node->IsExit(); }
IsNextIteration(const Node * n)740 inline bool IsNextIteration(const Node* n) { return n->IsNextIteration(); }
IsLoopCond(const Node * node)741 inline bool IsLoopCond(const Node* node) { return node->IsLoopCond(); }
IsControlTrigger(const Node * n)742 inline bool IsControlTrigger(const Node* n) { return n->IsControlTrigger(); }
IsSend(const Node * node)743 inline bool IsSend(const Node* node) { return node->IsSend(); }
IsRecv(const Node * node)744 inline bool IsRecv(const Node* node) { return node->IsRecv(); }
IsHostSend(const Node * node)745 inline bool IsHostSend(const Node* node) { return node->IsHostSend(); }
IsHostRecv(const Node * node)746 inline bool IsHostRecv(const Node* node) { return node->IsHostRecv(); }
747
748 // True for Nodes that mediate the transfer of values between processes.
IsTransferNode(const Node * n)749 inline bool IsTransferNode(const Node* n) { return IsSend(n) || IsRecv(n); }
750
IsConstant(const Node * node)751 inline bool IsConstant(const Node* node) { return node->IsConstant(); }
IsVariable(const Node * node)752 inline bool IsVariable(const Node* node) { return node->IsVariable(); }
IsIdentity(const Node * node)753 inline bool IsIdentity(const Node* node) { return node->IsIdentity(); }
754
755 // Returns true iff 'n' is a control flow node.
IsControlFlow(const Node * n)756 inline bool IsControlFlow(const Node* n) { return n->IsControlFlow(); }
757
758 // Returns true if the node only depends on its input's metadata
759 // (shape). Specifically, returns true for "Size", "Shape" and "Rank" ops.
IsMetadata(const Node * n)760 inline bool IsMetadata(const Node* n) { return n->IsMetadata(); }
761
IsScopedAllocator(const Node * n)762 inline bool IsScopedAllocator(const Node* n) { return n->IsScopedAllocator(); }
763
IsHostMemoryPreserving(const Node * node)764 inline bool IsHostMemoryPreserving(const Node* node) {
765 return IsIdentity(node) || IsControlFlow(node);
766 }
767
768 // Iterator for stepping through the nodes of a graph.
769 class NodeIter {
770 public:
771 NodeIter(const Graph* graph, int id);
772 bool operator==(const NodeIter& rhs);
773 bool operator!=(const NodeIter& rhs);
774 void operator++();
775 Node* operator*();
776 Node* operator->();
777
778 private:
779 // Invariant: id_ == graph_->num_node_ids() || graph_->FindId(id_) != nullptr
780 const Graph* graph_;
781 int id_;
782 };
783
784 // Iterator for stepping through the neighbors of a node.
785 class NeighborIter {
786 public:
787 NeighborIter(EdgeSet::const_iterator iter, bool incoming);
788 bool operator==(const NeighborIter& rhs);
789 bool operator!=(const NeighborIter& rhs);
790 void operator++();
791 Node* operator*();
792 Node* operator->();
793
794 private:
795 EdgeSet::const_iterator iter_;
796 bool incoming_;
797 };
798
799 // IMPLEMENTATION DETAILS, PLEASE IGNORE
800
NodeIter(const Graph * graph,int id)801 inline NodeIter::NodeIter(const Graph* graph, int id)
802 : graph_(graph), id_(id) {}
803
804 inline bool NodeIter::operator==(const NodeIter& rhs) {
805 DCHECK(graph_ == rhs.graph_);
806 return id_ == rhs.id_;
807 }
808
809 inline bool NodeIter::operator!=(const NodeIter& rhs) {
810 return !(*this == rhs);
811 }
812
813 inline void NodeIter::operator++() {
814 while (1) {
815 DCHECK_LE(id_, graph_->num_node_ids());
816 ++id_;
817 if (id_ >= graph_->num_node_ids() || graph_->FindNodeId(id_) != nullptr) {
818 return;
819 }
820 }
821 }
822
823 inline Node* NodeIter::operator*() { return graph_->FindNodeId(id_); }
824
825 inline Node* NodeIter::operator->() { return graph_->FindNodeId(id_); }
826
NeighborIter(EdgeSet::const_iterator iter,bool incoming)827 inline NeighborIter::NeighborIter(EdgeSet::const_iterator iter, bool incoming)
828 : iter_(iter), incoming_(incoming) {}
829
830 inline bool NeighborIter::operator==(const NeighborIter& rhs) {
831 return iter_ == rhs.iter_ && incoming_ == rhs.incoming_;
832 }
833
834 inline bool NeighborIter::operator!=(const NeighborIter& rhs) {
835 return !(*this == rhs);
836 }
837
838 inline void NeighborIter::operator++() { ++iter_; }
839
840 inline Node* NeighborIter::operator*() {
841 const Edge* e = *iter_;
842 return incoming_ ? e->src() : e->dst();
843 }
844
845 inline Node* NeighborIter::operator->() {
846 const Edge* e = *iter_;
847 return incoming_ ? e->src() : e->dst();
848 }
849
IsControlEdge()850 inline bool Edge::IsControlEdge() const {
851 // Note that if either src_output_ or dst_input_ is kControlSlot,
852 // so is the other one (AddEdge checks this).
853 return src_output_ == Graph::kControlSlot;
854 }
855
nodes()856 inline gtl::iterator_range<NodeIter> Graph::nodes() const {
857 // Note that NodeId 0 is always valid since we don't let the source
858 // node be removed from the graph.
859 return gtl::make_range(NodeIter(this, 0), NodeIter(this, num_node_ids()));
860 }
861
op_nodes()862 inline gtl::iterator_range<NodeIter> Graph::op_nodes() const {
863 // Note that NodeId 0 is always valid since we don't let the source
864 // node be removed from the graph.
865 //
866 // The current implementation of Graph maintains the invariant that the
867 // first two nodes are the source and sink nodes, and all other nodes are op
868 // nodes. This method (op_nodes()) relies on this invariant.
869 NodeIter begin(this, 0);
870 NodeIter end(this, num_node_ids());
871 if (begin != end) {
872 ++begin;
873 }
874 if (begin != end) {
875 ++begin;
876 }
877 return gtl::make_range(begin, end);
878 }
879
set_assigned_device_name_index(int index)880 inline void Node::set_assigned_device_name_index(int index) {
881 graph_->CheckDeviceNameIndex(index);
882 assigned_device_name_index_ = index;
883 }
884
set_assigned_device_name(const string & device_name)885 inline void Node::set_assigned_device_name(const string& device_name) {
886 graph_->set_assigned_device_name(this, device_name);
887 }
888
assigned_device_name()889 inline const string& Node::assigned_device_name() const {
890 return graph_->get_assigned_device_name(*this);
891 }
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_GRAPH_GRAPH_H_