1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // A Graph describes a set of computations that are to be
17 // performed, as well as the dependencies between those
18 // computations. The basic model is a DAG (directed acyclic graph) with
19 // * internal nodes representing computational operations to be performed;
20 // * edges represent dependencies, indicating the target may only be
21 // executed once the source has completed; and
22 // * predefined "source" (start) and "sink" (finish) nodes -- the source
23 // should be the only node that doesn't depend on anything, and the sink
24 // should be the only node that nothing depends on.
25 //
26 // Note: Node ids are intended to be relatively dense in the
27 // 0..max_id range, but there may be gaps since ids won't be reused.
28 //
29 // Note: Some dependencies between operations are due to one operation
30 // consuming the output of another. In fact operations can produce
31 // multiple outputs and consume multiple inputs, and some
32 // optimizations will care about which specific outputs are connected
33 // to which specific inputs. We therefore represent data dependency
34 // between output O of layer A and input I of layer B using
35 // "input index" and "output index" labels per edge.
36
37 #ifndef TENSORFLOW_CORE_GRAPH_GRAPH_H_
38 #define TENSORFLOW_CORE_GRAPH_GRAPH_H_
39
40 #include <functional>
41 #include <string>
42 #include <vector>
43
44 #include "absl/types/optional.h"
45 #include "tensorflow/core/framework/function.h"
46 #include "tensorflow/core/framework/node_def.pb.h"
47 #include "tensorflow/core/framework/node_def_util.h"
48 #include "tensorflow/core/framework/op.h"
49 #include "tensorflow/core/framework/types.h"
50 #include "tensorflow/core/graph/edgeset.h"
51 #include "tensorflow/core/lib/core/arena.h"
52 #include "tensorflow/core/lib/core/refcount.h"
53 #include "tensorflow/core/lib/core/status.h"
54 #include "tensorflow/core/lib/gtl/iterator_range.h"
55 #include "tensorflow/core/platform/logging.h"
56 #include "tensorflow/core/platform/macros.h"
57 #include "tensorflow/core/platform/types.h"
58
59 namespace tensorflow {
60
61 class Edge;
62 class EdgeSetTest;
63 class Graph;
64 class GraphDef;
65 class Node;
66 struct OutputTensor;
67 class VersionDef;
68 class WhileContext;
69
70 class NeighborIter; // Declared below
71 class NodeIter; // Declared below
72
// Records which runtime a Graph instance was created by.
enum class ConstructionContext {
  // Origin is not tracked.
  kNotTracked,
  // Created by `tensorflow::DirectSession` (the TF1 session API).
  kDirectSession,
  // Registered from the TF2 eager runtime.
  kEagerRuntime,
};
79
80 class Node {
81 public:
82 std::string DebugString() const;
id()83 int id() const { return id_; }
cost_id()84 int cost_id() const { return cost_id_; }
85 const std::string& name() const;
86 void set_name(std::string name);
87 const std::string& type_string() const;
88
89 // def() provides the NodeDef the user supplied, but the specifics
90 // of this Node may have changed due to placement, optimization, etc.
91 // In particular:
92 // * def().name() will match name();
93 // * def().op() will match type_string() and op_def().name();
94 // * def().input() is not reliable, use "in_edges()" below instead;
95 // * def().device() is the "user's requested device" and may not match
96 // the actual assigned device, see assigned_device_name() below;
97 // * def().attr() is authoritative.
98 // TODO(irving): Replace with NodeInfo.
99 const NodeDef& def() const;
100 const OpDef& op_def() const;
101
102 // input and output types
103 int32 num_inputs() const;
104 DataType input_type(int32 i) const;
105 const DataTypeVector& input_types() const;
106
107 int32 num_outputs() const;
108 DataType output_type(int32 o) const;
109 const DataTypeVector& output_types() const;
110
111 // The device requested by the user. For the actual assigned device,
112 // use assigned_device_name() below.
113 const std::string& requested_device() const;
114
115 // This changes the user requested device but not necessarily the device that
116 // on which the operation will run.
117 void set_requested_device(const std::string& device);
118
119 // This gives the device the runtime has assigned this node to. If
120 // you want the device the user requested, use def().device() instead.
121 // TODO(josh11b): Validate that the assigned_device, if not empty:
122 // fully specifies a device, and satisfies def().device().
123 // TODO(josh11b): Move assigned_device_name outside of Node into a
124 // NodeId->DeviceName map.
125 const std::string& assigned_device_name() const;
126 void set_assigned_device_name(const std::string& device_name);
has_assigned_device_name()127 bool has_assigned_device_name() const {
128 return assigned_device_name_index_ > 0;
129 }
assigned_device_name_index()130 int assigned_device_name_index() const { return assigned_device_name_index_; }
131 void set_assigned_device_name_index(int index);
132
133 // Sets 'original_node_names' field of this node's DebugInfo proto to
134 // 'names'.
135 void set_original_node_names(const std::vector<string>& names);
136
137 // Read only access to attributes
138 AttrSlice attrs() const;
139
140 // Inputs requested by the NodeDef. For the actual inputs, use in_edges.
141 const protobuf::RepeatedPtrField<string>& requested_inputs() const;
142
143 // Get the neighboring nodes via edges either in or out of this node. This
144 // includes control edges.
145 gtl::iterator_range<NeighborIter> in_nodes() const;
146 gtl::iterator_range<NeighborIter> out_nodes() const;
in_edges()147 const EdgeSet& in_edges() const { return in_edges_; }
out_edges()148 const EdgeSet& out_edges() const { return out_edges_; }
149
150 // Node type helpers.
IsSource()151 bool IsSource() const { return id() == 0; }
IsSink()152 bool IsSink() const { return id() == 1; }
153 // Anything other than the special Source & Sink nodes.
IsOp()154 bool IsOp() const { return id() > 1; }
155
156 // Node class helpers
IsSwitch()157 bool IsSwitch() const { return class_ == NC_SWITCH; }
IsMerge()158 bool IsMerge() const { return class_ == NC_MERGE; }
IsEnter()159 bool IsEnter() const { return class_ == NC_ENTER; }
IsExit()160 bool IsExit() const { return class_ == NC_EXIT; }
IsNextIteration()161 bool IsNextIteration() const { return class_ == NC_NEXT_ITERATION; }
IsLoopCond()162 bool IsLoopCond() const { return class_ == NC_LOOP_COND; }
IsControlTrigger()163 bool IsControlTrigger() const { return class_ == NC_CONTROL_TRIGGER; }
IsSend()164 bool IsSend() const { return class_ == NC_SEND || class_ == NC_HOST_SEND; }
IsRecv()165 bool IsRecv() const { return class_ == NC_RECV || class_ == NC_HOST_RECV; }
IsConstant()166 bool IsConstant() const { return class_ == NC_CONSTANT; }
IsVariable()167 bool IsVariable() const { return class_ == NC_VARIABLE; }
IsIdentity()168 bool IsIdentity() const { return class_ == NC_IDENTITY; }
IsGetSessionHandle()169 bool IsGetSessionHandle() const { return class_ == NC_GET_SESSION_HANDLE; }
IsGetSessionTensor()170 bool IsGetSessionTensor() const { return class_ == NC_GET_SESSION_TENSOR; }
IsDeleteSessionTensor()171 bool IsDeleteSessionTensor() const {
172 return class_ == NC_DELETE_SESSION_TENSOR;
173 }
IsControlFlow()174 bool IsControlFlow() const {
175 return (class_ != NC_OTHER) && // Fast path
176 (IsSwitch() || IsMerge() || IsEnter() || IsExit() ||
177 IsNextIteration());
178 }
IsHostSend()179 bool IsHostSend() const { return class_ == NC_HOST_SEND; }
IsHostRecv()180 bool IsHostRecv() const { return class_ == NC_HOST_RECV; }
IsScopedAllocator()181 bool IsScopedAllocator() const { return class_ == NC_SCOPED_ALLOCATOR; }
IsCollective()182 bool IsCollective() const { return class_ == NC_COLLECTIVE; }
183
IsMetadata()184 bool IsMetadata() const { return class_ == NC_METADATA; }
IsFakeParam()185 bool IsFakeParam() const { return class_ == NC_FAKE_PARAM; }
IsPartitionedCall()186 bool IsPartitionedCall() const { return class_ == NC_PARTITIONED_CALL; }
187
188 // Returns true if this node is any kind of function call node.
189 //
190 // NOTE: "function call nodes" include partitioned call ops, symbolic gradient
191 // ops, and ops whose type_string is the name of a function ("function ops").
IsFunctionCall()192 bool IsFunctionCall() const {
193 return class_ == NC_PARTITIONED_CALL || class_ == NC_FUNCTION_OP ||
194 class_ == NC_SYMBOLIC_GRADIENT;
195 }
196
IsIfNode()197 bool IsIfNode() const { return class_ == NC_IF; }
IsWhileNode()198 bool IsWhileNode() const { return class_ == NC_WHILE; }
IsCaseNode()199 bool IsCaseNode() const { return class_ == NC_CASE; }
200 // Is this node a function input
IsArg()201 bool IsArg() const { return class_ == NC_ARG; }
202 // Is this node a function output
IsRetval()203 bool IsRetval() const { return class_ == NC_RETVAL; }
204
205 template <typename T>
AddAttr(const std::string & name,const T & val)206 void AddAttr(const std::string& name, const T& val) {
207 SetAttrValue(val, AddAttrHelper(name));
208 UpdateProperties();
209 }
210
AddAttr(const std::string & name,std::vector<string> && val)211 void AddAttr(const std::string& name, std::vector<string>&& val) {
212 MoveAttrValue(std::move(val), AddAttrHelper(name));
213 UpdateProperties();
214 }
215
216 void ClearAttr(const std::string& name);
217
218 // Returns into '*e' the edge connecting to the 'idx' input of this Node.
219 Status input_edge(int idx, const Edge** e) const;
220
221 // Returns into '*edges' the input data edges of this Node, indexed by input
222 // number. Does not return control edges.
223 Status input_edges(std::vector<const Edge*>* edges) const;
224
225 // Returns into '*n' the node that has an output connected to the
226 // 'idx' input of this Node.
227 Status input_node(int idx, const Node** n) const;
228 Status input_node(int idx, Node** n) const;
229
230 // Returns into '*t' the idx-th input tensor of this node, represented as the
231 // output tensor of input_node(idx).
232 Status input_tensor(int idx, OutputTensor* t) const;
233
while_ctx()234 WhileContext* while_ctx() const { return while_ctx_; }
set_while_ctx(WhileContext * while_ctx)235 void set_while_ctx(WhileContext* while_ctx) {
236 DCHECK(IsExit());
237 DCHECK(while_ctx_ == nullptr);
238 while_ctx_ = while_ctx;
239 }
240
properties()241 std::shared_ptr<NodeProperties> properties() const { return props_; }
242
243 // Sets the stack trace for the node. Assumes that getting and setting the
244 // stack trace for a given node will not race.
SetStackTrace(const std::shared_ptr<AbstractStackTrace> & stack_trace)245 void SetStackTrace(const std::shared_ptr<AbstractStackTrace>& stack_trace) {
246 stack_trace_ = stack_trace;
247 }
248
249 // Get the stack trace for when the node was instantiated.
GetStackTrace()250 const std::shared_ptr<AbstractStackTrace>& GetStackTrace() const {
251 return stack_trace_;
252 }
253
254 private:
255 friend class Graph;
256 Node();
257
258 // Stack trace for the user code for node instantiation. Can be shared across
259 // multiple nodes (e.g. when inlining).
260 std::shared_ptr<AbstractStackTrace> stack_trace_;
261
262 // Releases memory from props_, in addition to restoring *this to its
263 // uninitialized state.
264 void Clear();
265
266 // Make a copy of the Node's props_ if props_ is shared with
267 // other nodes. This must be called before mutating properties,
268 // e.g. in AddAttr.
269 void MaybeCopyOnWrite();
270
271 // Called after an attr has changed. Decides whether we need to update some
272 // property of the node (stored in props_).
273 void UpdateProperties();
274
275 AttrValue* AddAttrHelper(const std::string& name);
276
277 // A set of mutually exclusive classes for different kinds of nodes,
278 // class_ is initialized in the Node::Initialize routine based on the
279 // node's type_string().
280 enum NodeClass {
281 NC_UNINITIALIZED,
282 NC_SWITCH,
283 NC_MERGE,
284 NC_ENTER,
285 NC_EXIT,
286 NC_NEXT_ITERATION,
287 NC_LOOP_COND,
288 NC_CONTROL_TRIGGER,
289 NC_SEND,
290 NC_HOST_SEND,
291 NC_RECV,
292 NC_HOST_RECV,
293 NC_CONSTANT,
294 NC_VARIABLE,
295 NC_IDENTITY,
296 NC_GET_SESSION_HANDLE,
297 NC_GET_SESSION_TENSOR,
298 NC_DELETE_SESSION_TENSOR,
299 NC_METADATA,
300 NC_SCOPED_ALLOCATOR,
301 NC_COLLECTIVE,
302 NC_FAKE_PARAM,
303 NC_PARTITIONED_CALL,
304 NC_FUNCTION_OP,
305 NC_SYMBOLIC_GRADIENT,
306 NC_IF,
307 NC_WHILE,
308 NC_CASE,
309 NC_ARG,
310 NC_RETVAL,
311 NC_OTHER // Not a special kind of node
312 };
313
314 void Initialize(int id, int cost_id, std::shared_ptr<NodeProperties> props,
315 NodeClass node_class);
316
317 static NodeClass GetNodeClassForOp(const std::string& ts);
318
319 int id_; // -1 until Initialize() is called
320 int cost_id_; // -1 if there is no corresponding cost accounting node
321 NodeClass class_;
322
323 EdgeSet in_edges_;
324 EdgeSet out_edges_;
325
326 // NOTE(skyewm): inheriting from core::RefCounted may have a slight
327 // performance benefit over using shared_ptr, at the cost of manual ref
328 // counting
329 std::shared_ptr<NodeProperties> props_;
330
331 // Index within Graph::device_names_ of the name of device assigned
332 // to perform this computation.
333 int assigned_device_name_index_;
334
335 // A back-pointer to the Graph that owns this node. Currently, this exists
336 // solely to allow Node::[set_]assigned_device_name() to work. However, if all
337 // callers of Node::[set_]assigned_device_name() are modified to use the
338 // equivalent methods defined directly on Graph, then we can remove this
339 // field and reclaim that memory.
340 Graph* graph_;
341
342 // Set if this is an exit node of a while loop with an associated
343 // WhileContext. Otherwise null. (This is only set for exit nodes because
344 // they're the first nodes of a loop encountered while creating the gradient
345 // graph. Exit nodes that are part of while loop gradient graphs will not have
346 // this set.)
347 WhileContext* while_ctx_;
348
349 TF_DISALLOW_COPY_AND_ASSIGN(Node);
350 };
351
352 // Stores debug information associated with the Node.
353 struct NodeDebugInfo {
354 const std::string name;
355 std::vector<string> original_node_names;
356
357 NodeDebugInfo(const Node& n);
358 NodeDebugInfo(const NodeDef& ndef);
359 NodeDebugInfo(StringPiece node_name, bool has_experimental_debug_info,
360 const NodeDef_ExperimentalDebugInfo& experimental_debug_info);
361 };
362
363 // Represents an input of a node, i.e., the `index`-th input to `node`.
364 struct InputTensor {
365 Node* node;
366 int index;
367
InputTensorInputTensor368 InputTensor(Node* n, int i) : node(n), index(i) {}
InputTensorInputTensor369 InputTensor() : node(nullptr), index(0) {}
370
371 // Returns true if this InputTensor is identical to 'other'. Nodes are
372 // compared using pointer equality.
373 bool operator==(const InputTensor& other) const;
374
375 // A hash function for InputTensors. Nodes are hashed based on their pointer
376 // value.
377 struct Hash {
378 uint64 operator()(InputTensor const& s) const;
379 };
380 };
381
382 // Represents an output of a node, i.e., the `index`-th output of `node`. Note
383 // that a single `OutputTensor` can correspond to multiple `Edge`s if the output
384 // is consumed by multiple destination nodes.
385 struct OutputTensor {
386 Node* node;
387 int index;
388
OutputTensorOutputTensor389 OutputTensor(Node* n, int i) : node(n), index(i) {}
OutputTensorOutputTensor390 OutputTensor() : node(nullptr), index(0) {}
391
392 // Returns true if this OutputTensor is identical to 'other'. Nodes are
393 // compared using pointer equality.
394 bool operator==(const OutputTensor& other) const;
395
396 // A hash function for OutputTensors. Nodes are hashed based on their pointer
397 // value.
398 struct Hash {
399 uint64 operator()(OutputTensor const& s) const;
400 };
401 };
402
403 class Edge {
404 public:
src()405 Node* src() const { return src_; }
dst()406 Node* dst() const { return dst_; }
id()407 int id() const { return id_; }
408
409 // Return the index of the source output that produces the data
410 // carried by this edge. The special value kControlSlot is used
411 // for control dependencies.
src_output()412 int src_output() const { return src_output_; }
413
414 // Return the index of the destination input that consumes the data
415 // carried by this edge. The special value kControlSlot is used
416 // for control dependencies.
dst_input()417 int dst_input() const { return dst_input_; }
418
419 // Return true iff this is an edge that indicates a control-flow
420 // (as opposed to a data-flow) dependency.
421 bool IsControlEdge() const;
422
423 std::string DebugString() const;
424
425 private:
Edge()426 Edge() {}
427
428 friend class EdgeSetTest;
429 friend class Graph;
430 Node* src_;
431 Node* dst_;
432 int id_;
433 int src_output_;
434 int dst_input_;
435 };
436
437 // Allows for iteration of the edges of a Graph, by iterating the underlying
438 // Graph.edges_ vector while skipping over null entries.
439 class GraphEdgesIterable {
440 private:
441 const std::vector<Edge*>& edges_;
442
443 public:
GraphEdgesIterable(const std::vector<Edge * > & edges)444 explicit GraphEdgesIterable(const std::vector<Edge*>& edges)
445 : edges_(edges) {}
446
447 typedef Edge* value_type;
448
449 class const_iterator {
450 private:
451 // The underlying iterator.
452 std::vector<value_type>::const_iterator iter_;
453
454 // The end of the underlying iterator.
455 std::vector<value_type>::const_iterator end_;
456
457 // Advances iter_ until it reaches a non-null item, or reaches the end.
apply_filter()458 void apply_filter() {
459 while (iter_ != end_ && *iter_ == nullptr) {
460 ++iter_;
461 }
462 }
463
464 public:
const_iterator(std::vector<value_type>::const_iterator iter,std::vector<value_type>::const_iterator end)465 const_iterator(std::vector<value_type>::const_iterator iter,
466 std::vector<value_type>::const_iterator end)
467 : iter_(iter), end_(end) {
468 apply_filter();
469 }
470
471 bool operator==(const const_iterator& other) const {
472 return iter_ == other.iter_;
473 }
474
475 bool operator!=(const const_iterator& other) const {
476 return iter_ != other.iter_;
477 }
478
479 // This is the prefix increment operator (++x), which is the operator
480 // used by C++ range iteration (for (x : y) ...). We intentionally do not
481 // provide a postfix increment operator.
482 const_iterator& operator++() {
483 ++iter_;
484 apply_filter();
485 return *this;
486 }
487
488 value_type operator*() { return *iter_; }
489 };
490
begin()491 const_iterator begin() {
492 return const_iterator(edges_.begin(), edges_.end());
493 }
end()494 const_iterator end() { return const_iterator(edges_.end(), edges_.end()); }
495 };
496
497 // Thread compatible but not thread safe.
498 class Graph {
499 public:
500 // Constructs a graph with a single SOURCE (always id kSourceId) and a
501 // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK.
502 //
503 // The graph can hold ops found in the registry. `ops`s lifetime must be at
504 // least that of the constructed graph's.
505 explicit Graph(const OpRegistryInterface* ops);
506
507 // Constructs a graph with a single SOURCE (always id kSourceId) and a
508 // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK.
509 //
510 // The graph can hold ops found in `flib_def`. Unlike the constructor taking
511 // an OpRegistryInterface, this constructor copies the function definitions in
512 // `flib_def` so its lifetime may be shorter than that of the graph's. The
513 // OpRegistryInterface backing `flib_def` must still have the lifetime of the
514 // graph though.
515 explicit Graph(const FunctionLibraryDefinition& flib_def);
516
517 ~Graph();
518
519 static const int kControlSlot;
520
521 // The GraphDef version range of this graph (see graph.proto).
522 const VersionDef& versions() const;
523 void set_versions(const VersionDef& versions);
524
525 // Adds a new node to this graph, and returns it. Infers the Op and
526 // input/output types for the node. *this owns the returned instance.
527 // Returns nullptr and sets *status on error.
528 Node* AddNode(NodeDef node_def, Status* status);
529
530 // Copies *node, which may belong to another graph, to a new node,
531 // which is returned. Does not copy any edges. *this owns the
532 // returned instance.
533 Node* CopyNode(const Node* node);
534
535 // Removes a node from this graph, including all edges from or to it.
536 // *node should not be accessed after calling this function.
537 // REQUIRES: node->IsOp()
538 void RemoveNode(Node* node);
539
540 void Copy(const Graph& src);
541
542 // Adds an edge that connects the xth output of `source` to the yth input of
543 // `dest` and returns it. Does not update dest's NodeDef.
544 const Edge* AddEdge(Node* source, int x, Node* dest, int y);
545
546 // Adds a control edge (no data flows along this edge) that connects `source`
547 // to `dest`. If `dest`s NodeDef is missing the corresponding control input,
548 // adds the control input.
549 //
550 // If such a control edge already exists and `allow_duplicates` is false, no
551 // edge is added and the function returns nullptr. Otherwise the edge is
552 // unconditionally created and returned. The NodeDef is not updated if
553 // `allow_duplicates` is true.
554 // TODO(skyewm): // TODO(skyewm): allow_duplicates is needed only by
555 // graph_partition.cc. Figure out if we can do away with it.
556 const Edge* AddControlEdge(Node* source, Node* dest,
557 bool allow_duplicates = false);
558
559 // Removes edge from the graph. Does not update the destination node's
560 // NodeDef.
561 // REQUIRES: The edge must exist.
562 void RemoveEdge(const Edge* edge);
563
564 // Removes control edge `edge` from the graph. Note that this also updates
565 // the corresponding NodeDef to reflect the change.
566 // REQUIRES: The control edge must exist.
567 void RemoveControlEdge(const Edge* e);
568
569 // Updates the input to a node. The existing edge to `dst` is removed and an
570 // edge from `new_src` to `dst` is created. The NodeDef associated with `dst`
571 // is also updated.
572 Status UpdateEdge(Node* new_src, int new_src_index, Node* dst, int dst_index);
573
574 // Like AddEdge but updates dst's NodeDef. Used to add an input edge to a
575 // "While" op during gradient construction, see AddInputWhileHack in
576 // python_api.h for more details.
577 Status AddWhileInputHack(Node* new_src, int new_src_index, Node* dst);
578
579 // Adds the function and gradient definitions in `fdef_lib` to this graph's op
580 // registry. Ignores duplicate functions, and returns a bad status if an
581 // imported function differs from an existing function or op with the same
582 // name.
583 Status AddFunctionLibrary(const FunctionDefLibrary& fdef_lib);
584
585 // The number of live nodes in the graph.
586 //
587 // Because nodes can be removed from the graph, num_nodes() is often
588 // smaller than num_node_ids(). If one needs to create an array of
589 // nodes indexed by node ids, num_node_ids() should be used as the
590 // array's size.
num_nodes()591 int num_nodes() const { return num_nodes_; }
592
593 // The number of live nodes in the graph, excluding the Source and Sink nodes.
num_op_nodes()594 int num_op_nodes() const {
595 DCHECK_GE(num_nodes_, 2);
596 return num_nodes_ - 2;
597 }
598
599 // The number of live edges in the graph.
600 //
601 // Because edges can be removed from the graph, num_edges() is often
602 // smaller than num_edge_ids(). If one needs to create an array of
603 // edges indexed by edge ids, num_edge_ids() should be used as the
604 // array's size.
num_edges()605 int num_edges() const { return num_edges_; }
606
607 // Serialize the nodes starting at `from_node_id` to a GraphDef.
608 void ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const;
609
610 // Serialize to a GraphDef.
611 void ToGraphDef(GraphDef* graph_def) const;
612
613 // This version can be called from debugger to inspect the graph content.
614 // Use the previous version outside debug context for efficiency reasons.
615 //
616 // Note: We do not expose a DebugString() API, since GraphDef.DebugString() is
617 // not defined in some TensorFlow builds.
618 GraphDef ToGraphDefDebug() const;
619
620 // Generate new node name with the specified prefix that is unique
621 // across this graph.
622 std::string NewName(StringPiece prefix);
623
624 // Access to the list of all nodes. Example usage:
625 // for (Node* node : graph.nodes()) { ... }
626 gtl::iterator_range<NodeIter> nodes() const;
627
628 // Access to the list of all nodes, excluding the Source and Sink nodes.
629 gtl::iterator_range<NodeIter> op_nodes() const;
630
631 // Returns one more than the maximum id assigned to any node.
num_node_ids()632 int num_node_ids() const { return nodes_.size(); }
633
634 // Returns the node associated with an id, or nullptr if no node
635 // with that id (the node with that id was removed and the id has
636 // not yet been re-used). *this owns the returned instance.
637 // REQUIRES: 0 <= id < num_node_ids().
FindNodeId(int id)638 Node* FindNodeId(int id) const { return nodes_[id]; }
639
640 // Returns one more than the maximum id assigned to any edge.
num_edge_ids()641 int num_edge_ids() const { return edges_.size(); }
642
643 // Returns the Edge associated with an id, or nullptr if no edge
644 // with that id (the edge with that id was removed and the id has
645 // not yet been re-used). *this owns the returned instance.
646 // REQUIRES: 0 <= id < num_edge_ids().
FindEdgeId(int id)647 const Edge* FindEdgeId(int id) const { return edges_[id]; }
648
649 // Access to the set of all edges. Example usage:
650 // for (const Edge* e : graph.edges()) { ... }
edges()651 GraphEdgesIterable edges() const { return GraphEdgesIterable(edges_); }
652
653 // The pre-defined nodes.
654 enum { kSourceId = 0, kSinkId = 1 };
source_node()655 Node* source_node() const { return FindNodeId(kSourceId); }
sink_node()656 Node* sink_node() const { return FindNodeId(kSinkId); }
657
op_registry()658 const OpRegistryInterface* op_registry() const { return &ops_; }
flib_def()659 const FunctionLibraryDefinition& flib_def() const { return ops_; }
660
CheckDeviceNameIndex(int index)661 void CheckDeviceNameIndex(int index) {
662 DCHECK_GE(index, 0);
663 DCHECK_LT(index, static_cast<int>(device_names_.size()));
664 }
665
666 int InternDeviceName(const std::string& device_name);
667
get_assigned_device_name(const Node & node)668 const std::string& get_assigned_device_name(const Node& node) const {
669 return device_names_[node.assigned_device_name_index()];
670 }
671
set_assigned_device_name_index(Node * node,int device_name_index)672 void set_assigned_device_name_index(Node* node, int device_name_index) {
673 CheckDeviceNameIndex(device_name_index);
674 node->assigned_device_name_index_ = device_name_index;
675 }
676
set_assigned_device_name(Node * node,const std::string & device_name)677 void set_assigned_device_name(Node* node, const std::string& device_name) {
678 node->assigned_device_name_index_ = InternDeviceName(device_name);
679 }
680
681 // Returns OK if `node` is non-null and belongs to this graph
682 Status IsValidNode(const Node* node) const;
683
684 // Returns OK if IsValidNode(`node`) and `idx` is a valid output. Does not
685 // accept control outputs.
686 Status IsValidOutputTensor(const Node* node, int idx) const;
687
688 // Returns OK if IsValidNode(`node`) and `idx` a valid input. Does not accept
689 // control inputs.
690 Status IsValidInputTensor(const Node* node, int idx) const;
691
692 // Create and return a new WhileContext owned by this graph. This is called
693 // when a new while loop is created. `frame_name` must be unique among
694 // WhileContexts in this graph.
695 Status AddWhileContext(StringPiece frame_name, std::vector<Node*> enter_nodes,
696 std::vector<Node*> exit_nodes,
697 OutputTensor cond_output,
698 std::vector<OutputTensor> body_inputs,
699 std::vector<OutputTensor> body_outputs,
700 WhileContext** result);
701
702 // Builds a node name to node pointer index for all nodes in the graph.
703 std::unordered_map<string, Node*> BuildNodeNameIndex() const;
704
GetConstArgIndicesCache()705 absl::optional<std::vector<bool>>& GetConstArgIndicesCache() const {
706 return const_arg_indices_cache_;
707 }
708
709 // TODO(kkb): Add to the constructor when it becomes managable.
710 // Sets the graph construction context.
SetConstructionContext(ConstructionContext construction_context)711 void SetConstructionContext(ConstructionContext construction_context) {
712 construction_context_ = construction_context;
713 }
714
715 // TODO(kkb): Rename to `GetConstructionContext` once we're comfortable
716 // making this stable and make it available widely.
717 // Returns the graph construction context. It's `kUnknown` if not set.
GetConstructionContextInternal()718 ConstructionContext GetConstructionContextInternal() const {
719 return construction_context_;
720 }
721
722 // TODO(josh11b): uint64 hash() const;
723
724 private:
725 // If cost_node is non-null, then cost accounting (in CostModel)
726 // will be associated with that node rather than the new one being
727 // created.
728 //
729 // Ownership of the returned Node is not transferred to caller.
730 Node* AllocateNode(std::shared_ptr<NodeProperties> props,
731 const Node* cost_node, Node::NodeClass node_class);
732 void ReleaseNode(Node* node);
733 // Insert edge in free_edges_ for possible reuse.
734 void RecycleEdge(const Edge* edge);
735 // Registry of all known ops, including functions.
736 FunctionLibraryDefinition ops_;
737
738 // GraphDef versions
739 const std::unique_ptr<VersionDef> versions_;
740
741 // Allocator which will give us good locality.
742 core::Arena arena_;
743
744 // Map from node ids to allocated nodes. nodes_[id] may be nullptr if
745 // the node with that id was removed from the graph.
746 std::vector<Node*> nodes_;
747
748 // Number of nodes alive.
749 int64 num_nodes_ = 0;
750
751 // Map from edge ids to allocated edges. edges_[id] may be nullptr if
752 // the edge with that id was removed from the graph.
753 std::vector<Edge*> edges_;
754
755 // The number of entries in edges_ that are not nullptr.
756 int num_edges_ = 0;
757
758 // Allocated but free nodes and edges.
759 std::vector<Node*> free_nodes_;
760 std::vector<Edge*> free_edges_;
761
762 // For generating unique names.
763 int name_counter_ = 0;
764
765 // In most graphs, the number of unique values used for the
766 // Node::assigned_device_name() property is quite small. If the graph is
767 // large, then this duplication of values can consume a significant amount of
768 // memory. Instead, we represent the same information using an interning
769 // table, which consists of a vector of unique strings (device_names_), as
770 // well a map (device_names_map_) from unique strings to indices within the
771 // unique string table.
772 //
773 // The InternDeviceName() method handles adding a new entry into the table,
774 // or locating the index of an existing entry.
775 //
776 // The fact that Node::assigned_device_name() is implemented using an
777 // interning table is intentionally public. This allows algorithms that
778 // frequently access this field to do so efficiently, especially for the case
779 // where the assigned_device_name of one Node is copied directly from that
780 // of another Node.
781
782 // A table of the unique assigned device names. Indices do NOT correspond
783 // to node IDs. Index 0 is always the empty string.
784 std::vector<string> device_names_;
785
// Maps unique device names to their indices within device_names_.
787 std::unordered_map<string, int> device_names_map_;
788
789 // All the while contexts owned by this graph, keyed by frame name,
790 // corresponding to all the while loops contained in this graph (including
791 // nested loops). The stored contexts are usually accessed via
792 // AddWhileContext() or Node::while_ctx(), but this manages the lifetime.
793 std::map<string, WhileContext> while_ctxs_;
794
795 // Cache of the indices of the arguments which need to be constant for the XLA
796 // compilation.
797 mutable absl::optional<std::vector<bool>> const_arg_indices_cache_;
798
// Indicates the context in which this Graph instance was constructed.
800 ConstructionContext construction_context_ = ConstructionContext::kNotTracked;
801
802 TF_DISALLOW_COPY_AND_ASSIGN(Graph);
803 };
804
805 // TODO(josh11b): We may want to support keeping an index on various
806 // node/edge attributes in a graph, particularly node names.
807
808 // Helper routines
809
IsSource(const Node * node)810 inline bool IsSource(const Node* node) { return node->IsSource(); }
IsSink(const Node * node)811 inline bool IsSink(const Node* node) { return node->IsSink(); }
IsSwitch(const Node * node)812 inline bool IsSwitch(const Node* node) { return node->IsSwitch(); }
IsMerge(const Node * node)813 inline bool IsMerge(const Node* node) { return node->IsMerge(); }
IsEnter(const Node * node)814 inline bool IsEnter(const Node* node) { return node->IsEnter(); }
IsExit(const Node * node)815 inline bool IsExit(const Node* node) { return node->IsExit(); }
IsNextIteration(const Node * n)816 inline bool IsNextIteration(const Node* n) { return n->IsNextIteration(); }
IsLoopCond(const Node * node)817 inline bool IsLoopCond(const Node* node) { return node->IsLoopCond(); }
IsControlTrigger(const Node * n)818 inline bool IsControlTrigger(const Node* n) { return n->IsControlTrigger(); }
IsSend(const Node * node)819 inline bool IsSend(const Node* node) { return node->IsSend(); }
IsRecv(const Node * node)820 inline bool IsRecv(const Node* node) { return node->IsRecv(); }
IsHostSend(const Node * node)821 inline bool IsHostSend(const Node* node) { return node->IsHostSend(); }
IsHostRecv(const Node * node)822 inline bool IsHostRecv(const Node* node) { return node->IsHostRecv(); }
823
824 // True for Nodes that mediate the transfer of values between processes.
IsTransferNode(const Node * n)825 inline bool IsTransferNode(const Node* n) { return IsSend(n) || IsRecv(n); }
826
IsConstant(const Node * node)827 inline bool IsConstant(const Node* node) { return node->IsConstant(); }
IsVariable(const Node * node)828 inline bool IsVariable(const Node* node) { return node->IsVariable(); }
IsIdentity(const Node * node)829 inline bool IsIdentity(const Node* node) { return node->IsIdentity(); }
830
831 // Returns true iff 'n' is a control flow node.
IsControlFlow(const Node * n)832 inline bool IsControlFlow(const Node* n) { return n->IsControlFlow(); }
833
834 // Returns true if the node only depends on its input's metadata
835 // (shape). Specifically, returns true for "Size", "Shape" and "Rank" ops.
IsMetadata(const Node * n)836 inline bool IsMetadata(const Node* n) { return n->IsMetadata(); }
837
IsScopedAllocator(const Node * n)838 inline bool IsScopedAllocator(const Node* n) { return n->IsScopedAllocator(); }
839
IsHostMemoryPreserving(const Node * node)840 inline bool IsHostMemoryPreserving(const Node* node) {
841 return IsIdentity(node) || IsControlFlow(node);
842 }
843
844 // NOTE: We declare Reference type of NodeIter and NeighborIter as Node* (see
845 // https://en.cppreference.com/w/cpp/iterator/iterator).
846
847 // Iterator for stepping through the nodes of a graph.
848 class NodeIter
849 : public std::iterator<std::forward_iterator_tag, Node, std::ptrdiff_t,
850 /*Pointer*/ Node*, /*Reference*/ Node*> {
851 public:
852 NodeIter(const Graph* graph, int id);
853 bool operator==(const NodeIter& rhs) const;
854 bool operator!=(const NodeIter& rhs) const;
855 void operator++();
856 reference operator*() const;
857 pointer operator->() const;
858
859 private:
860 // Invariant: id_ == graph_->num_node_ids() || graph_->FindId(id_) != nullptr
861 const Graph* graph_;
862 int id_;
863 };
864
865 // Iterator for stepping through the neighbors of a node.
866 class NeighborIter
867 : public std::iterator<std::forward_iterator_tag, Node, std::ptrdiff_t,
868 /*Pointer*/ Node*, /*Reference*/ Node*> {
869 public:
870 NeighborIter(EdgeSet::const_iterator iter, bool incoming);
871 bool operator==(const NeighborIter& rhs) const;
872 bool operator!=(const NeighborIter& rhs) const;
873 void operator++();
874 reference operator*() const;
875 pointer operator->() const;
876
877 private:
878 EdgeSet::const_iterator iter_;
879 bool incoming_;
880 };
881
882 // IMPLEMENTATION DETAILS, PLEASE IGNORE
883
NodeIter(const Graph * graph,int id)884 inline NodeIter::NodeIter(const Graph* graph, int id)
885 : graph_(graph), id_(id) {}
886
887 inline bool NodeIter::operator==(const NodeIter& rhs) const {
888 DCHECK(graph_ == rhs.graph_);
889 return id_ == rhs.id_;
890 }
891
892 inline bool NodeIter::operator!=(const NodeIter& rhs) const {
893 return !(*this == rhs);
894 }
895
896 inline void NodeIter::operator++() {
897 while (1) {
898 DCHECK_LE(id_, graph_->num_node_ids());
899 ++id_;
900 if (id_ >= graph_->num_node_ids() || graph_->FindNodeId(id_) != nullptr) {
901 return;
902 }
903 }
904 }
905
906 inline Node* NodeIter::operator*() const { return graph_->FindNodeId(id_); }
907
908 inline Node* NodeIter::operator->() const { return graph_->FindNodeId(id_); }
909
NeighborIter(EdgeSet::const_iterator iter,bool incoming)910 inline NeighborIter::NeighborIter(EdgeSet::const_iterator iter, bool incoming)
911 : iter_(iter), incoming_(incoming) {}
912
913 inline bool NeighborIter::operator==(const NeighborIter& rhs) const {
914 return iter_ == rhs.iter_ && incoming_ == rhs.incoming_;
915 }
916
917 inline bool NeighborIter::operator!=(const NeighborIter& rhs) const {
918 return !(*this == rhs);
919 }
920
921 inline void NeighborIter::operator++() { ++iter_; }
922
923 inline Node* NeighborIter::operator*() const {
924 const Edge* e = *iter_;
925 return incoming_ ? e->src() : e->dst();
926 }
927
928 inline Node* NeighborIter::operator->() const {
929 const Edge* e = *iter_;
930 return incoming_ ? e->src() : e->dst();
931 }
932
IsControlEdge()933 inline bool Edge::IsControlEdge() const {
934 // Note that if either src_output_ or dst_input_ is kControlSlot,
935 // so is the other one (AddEdge checks this).
936 return src_output_ == Graph::kControlSlot;
937 }
938
nodes()939 inline gtl::iterator_range<NodeIter> Graph::nodes() const {
940 // Note that NodeId 0 is always valid since we don't let the source
941 // node be removed from the graph.
942 return gtl::make_range(NodeIter(this, 0), NodeIter(this, num_node_ids()));
943 }
944
op_nodes()945 inline gtl::iterator_range<NodeIter> Graph::op_nodes() const {
946 // Note that NodeId 0 is always valid since we don't let the source
947 // node be removed from the graph.
948 //
949 // The current implementation of Graph maintains the invariant that the
950 // first two nodes are the source and sink nodes, and all other nodes are op
951 // nodes. This method (op_nodes()) relies on this invariant.
952 NodeIter begin(this, 0);
953 NodeIter end(this, num_node_ids());
954 if (begin != end) {
955 ++begin;
956 }
957 if (begin != end) {
958 ++begin;
959 }
960 return gtl::make_range(begin, end);
961 }
962
set_assigned_device_name_index(int index)963 inline void Node::set_assigned_device_name_index(int index) {
964 graph_->CheckDeviceNameIndex(index);
965 assigned_device_name_index_ = index;
966 }
967
set_assigned_device_name(const std::string & device_name)968 inline void Node::set_assigned_device_name(const std::string& device_name) {
969 graph_->set_assigned_device_name(this, device_name);
970 }
971
assigned_device_name()972 inline const std::string& Node::assigned_device_name() const {
973 return graph_->get_assigned_device_name(*this);
974 }
975
976 } // namespace tensorflow
977
978 #endif // TENSORFLOW_CORE_GRAPH_GRAPH_H_
979