1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // A Graph describes a set of computations that are to be
17 // performed, as well as the dependencies between those
18 // computations. The basic model is a DAG (directed acyclic graph) with
19 // * internal nodes representing computational operations to be performed;
20 // * edges represent dependencies, indicating the target may only be
21 // executed once the source has completed; and
22 // * predefined "source" (start) and "sink" (finish) nodes -- the source
23 // should be the only node that doesn't depend on anything, and the sink
24 // should be the only node that nothing depends on.
25 //
26 // Note: Node ids are intended to be relatively dense in the
27 // 0..max_id range, but there may be gaps since ids won't be reused.
28 //
29 // Note: Some dependencies between operations are due to one operation
30 // consuming the output of another. In fact operations can produce
31 // multiple outputs and consume multiple inputs, and some
32 // optimizations will care about which specific outputs are connected
33 // to which specific inputs. We therefore represent data dependency
34 // between output O of layer A and input I of layer B using
35 // "input index" and "output index" labels per edge.
36
37 #ifndef TENSORFLOW_CORE_GRAPH_GRAPH_H_
38 #define TENSORFLOW_CORE_GRAPH_GRAPH_H_
39
40 #include <functional>
41 #include <memory>
42 #include <string>
43 #include <vector>
44
45 #include "absl/types/optional.h"
46 #include "tensorflow/core/framework/full_type.pb.h"
47 #include "tensorflow/core/framework/function.h"
48 #include "tensorflow/core/framework/node_def.pb.h"
49 #include "tensorflow/core/framework/node_def_util.h"
50 #include "tensorflow/core/framework/op.h"
51 #include "tensorflow/core/framework/types.h"
52 #include "tensorflow/core/graph/edgeset.h"
53 #include "tensorflow/core/lib/core/arena.h"
54 #include "tensorflow/core/lib/core/refcount.h"
55 #include "tensorflow/core/lib/core/status.h"
56 #include "tensorflow/core/lib/gtl/iterator_range.h"
57 #include "tensorflow/core/platform/logging.h"
58 #include "tensorflow/core/platform/macros.h"
59 #include "tensorflow/core/platform/stringpiece.h"
60 #include "tensorflow/core/platform/types.h"
61
62 namespace tensorflow {
63
64 class Edge;
65 class EdgeSetTest;
66 class Graph;
67 class GraphDef;
68 class Node;
69 struct OutputTensor;
70 class VersionDef;
71 class WhileContext;
72
73 class NeighborIter; // Declared below
74 class NodeIter; // Declared below
75
// Indicates where the graph instance originated from.
enum class ConstructionContext {
  kNotTracked,     // Not tracked.
  kDirectSession,  // From `tensorflow::DirectSession`, TF1 session API.
  kEagerRuntime,   // Registered from TF2 eager runtime.
};
82
class Node {
 public:
  // Returns a human-readable description of this node for debugging.
  std::string DebugString() const;

  // Unique id of this node within the owning Graph; -1 until Initialize()
  // is called.
  int id() const { return id_; }

  // Id of the node used for cost accounting; -1 if there is no corresponding
  // cost accounting node.
  int cost_id() const { return cost_id_; }

  const std::string& name() const;
  void set_name(std::string name);

  // The operation type of this node (matches def().op(), see def() below).
  const std::string& type_string() const;

  // def() provides the NodeDef the user supplied, but the specifics
  // of this Node may have changed due to placement, optimization, etc.
  // In particular:
  // * def().name() will match name();
  // * def().op() will match type_string() and op_def().name();
  // * def().input() is not reliable, use "in_edges()" below instead;
  // * def().device() is the "user's requested device" and may not match
  //   the actual assigned device, see assigned_device_name() below;
  // * def().attr() is authoritative.
  // TODO(irving): Replace with NodeInfo.
  const NodeDef& def() const;
  const OpDef& op_def() const;

  // TODO(mdan): This is only used by control_flow_deps_o_chains. Remove?
  NodeDef* mutable_def();

  // Input and output types.
  int32 num_inputs() const;
  DataType input_type(int32_t i) const;
  const DataTypeVector& input_types() const;

  int32 num_outputs() const;
  DataType output_type(int32_t o) const;
  const DataTypeVector& output_types() const;

  // The device requested by the user. For the actual assigned device,
  // use assigned_device_name() below.
  const std::string& requested_device() const;

  // This changes the user requested device but not necessarily the device on
  // which the operation will run.
  void set_requested_device(const std::string& device);

  // This gives the device the runtime has assigned this node to. If
  // you want the device the user requested, use def().device() instead.
  // TODO(josh11b): Validate that the assigned_device, if not empty:
  // fully specifies a device, and satisfies def().device().
  // TODO(josh11b): Move assigned_device_name outside of Node into a
  // NodeId->DeviceName map.
  const std::string& assigned_device_name() const;
  void set_assigned_device_name(const std::string& device_name);
  // True iff a device has been assigned. Index 0 of the owning Graph's
  // device-name interning table is always the empty string, so index > 0
  // means a real device name has been set.
  bool has_assigned_device_name() const {
    return assigned_device_name_index_ > 0;
  }
  // Index of the assigned device name within Graph::device_names_.
  int assigned_device_name_index() const { return assigned_device_name_index_; }
  void set_assigned_device_name_index(int index);

  // Sets 'original_node_names' field of this node's DebugInfo proto to
  // 'names'.
  void set_original_node_names(const std::vector<string>& names);

  // Read only access to attributes.
  AttrSlice attrs() const;

  // Inputs requested by the NodeDef. For the actual inputs, use in_edges.
  const protobuf::RepeatedPtrField<string>& requested_inputs() const;

  // Get the neighboring nodes via edges either in or out of this node. This
  // includes control edges.
  gtl::iterator_range<NeighborIter> in_nodes() const;
  gtl::iterator_range<NeighborIter> out_nodes() const;
  const EdgeSet& in_edges() const { return in_edges_; }
  const EdgeSet& out_edges() const { return out_edges_; }

  // Node type helpers. The SOURCE node always has id 0 and the SINK node
  // always has id 1 (see Graph::kSourceId / Graph::kSinkId).
  bool IsSource() const { return id() == 0; }
  bool IsSink() const { return id() == 1; }
  // Anything other than the special Source & Sink nodes.
  bool IsOp() const { return id() > 1; }

  // Node class helpers: predicates over class_, which is derived from the
  // node's type_string() during Initialize().
  bool IsSwitch() const { return class_ == NC_SWITCH; }
  bool IsMerge() const { return class_ == NC_MERGE; }
  bool IsEnter() const { return class_ == NC_ENTER; }
  bool IsExit() const { return class_ == NC_EXIT; }
  bool IsNextIteration() const { return class_ == NC_NEXT_ITERATION; }
  bool IsLoopCond() const { return class_ == NC_LOOP_COND; }
  bool IsControlTrigger() const { return class_ == NC_CONTROL_TRIGGER; }
  bool IsSend() const { return class_ == NC_SEND || class_ == NC_HOST_SEND; }
  bool IsRecv() const { return class_ == NC_RECV || class_ == NC_HOST_RECV; }
  bool IsConstant() const { return class_ == NC_CONSTANT; }
  bool IsVariable() const { return class_ == NC_VARIABLE; }
  bool IsIdentity() const { return class_ == NC_IDENTITY; }
  bool IsGetSessionHandle() const { return class_ == NC_GET_SESSION_HANDLE; }
  bool IsGetSessionTensor() const { return class_ == NC_GET_SESSION_TENSOR; }
  bool IsDeleteSessionTensor() const {
    return class_ == NC_DELETE_SESSION_TENSOR;
  }
  // True iff this node is one of the five control-flow primitives.
  bool IsControlFlow() const {
    return (class_ != NC_OTHER) &&  // Fast path
           (IsSwitch() || IsMerge() || IsEnter() || IsExit() ||
            IsNextIteration());
  }
  bool IsHostSend() const { return class_ == NC_HOST_SEND; }
  bool IsHostRecv() const { return class_ == NC_HOST_RECV; }
  bool IsScopedAllocator() const { return class_ == NC_SCOPED_ALLOCATOR; }
  bool IsCollective() const { return class_ == NC_COLLECTIVE; }

  bool IsMetadata() const { return class_ == NC_METADATA; }
  bool IsFakeParam() const { return class_ == NC_FAKE_PARAM; }
  bool IsPartitionedCall() const { return class_ == NC_PARTITIONED_CALL; }

  // Returns true if this node is any kind of function call node.
  //
  // NOTE: "function call nodes" include partitioned call ops, symbolic gradient
  // ops, and ops whose type_string is the name of a function ("function ops").
  bool IsFunctionCall() const {
    return class_ == NC_PARTITIONED_CALL || class_ == NC_FUNCTION_OP ||
           class_ == NC_SYMBOLIC_GRADIENT;
  }

  bool IsIfNode() const { return class_ == NC_IF; }
  bool IsWhileNode() const { return class_ == NC_WHILE; }
  bool IsCaseNode() const { return class_ == NC_CASE; }
  // Is this node a function input.
  bool IsArg() const { return class_ == NC_ARG; }
  // Is this node a function output.
  bool IsRetval() const { return class_ == NC_RETVAL; }

  bool IsDistributedCommunication() const {
    return op_def().is_distributed_communication();
  }

  // Sets attribute `name` to `val`, then refreshes the properties that are
  // derived from the attrs (see UpdateProperties()).
  template <typename T>
  void AddAttr(const std::string& name, const T& val) {
    SetAttrValue(val, AddAttrHelper(name));
    UpdateProperties();
  }

  // Overload that moves the string list into the attr value instead of
  // copying it.
  void AddAttr(const std::string& name, std::vector<string>&& val) {
    MoveAttrValue(std::move(val), AddAttrHelper(name));
    UpdateProperties();
  }

  void ClearAttr(const std::string& name);

  // Returns into '*e' the edge connecting to the 'idx' input of this Node.
  Status input_edge(int idx, const Edge** e) const;

  // Returns into '*edges' the input data edges of this Node, indexed by input
  // number. Does not return control edges.
  Status input_edges(std::vector<const Edge*>* edges) const;

  // Returns into '*n' the node that has an output connected to the
  // 'idx' input of this Node.
  Status input_node(int idx, const Node** n) const;
  Status input_node(int idx, Node** n) const;

  // Returns into '*t' the idx-th input tensor of this node, represented as the
  // output tensor of input_node(idx).
  Status input_tensor(int idx, OutputTensor* t) const;

  WhileContext* while_ctx() const { return while_ctx_; }
  // May only be called once, and only on an Exit node (see while_ctx_ below).
  void set_while_ctx(WhileContext* while_ctx) {
    DCHECK(IsExit());
    DCHECK(while_ctx_ == nullptr);
    while_ctx_ = while_ctx;
  }

  std::shared_ptr<NodeProperties> properties() const { return props_; }

  // Sets the stack trace for the node. Assumes that getting and setting the
  // stack trace for a given node will not race.
  void SetStackTrace(const std::shared_ptr<AbstractStackTrace>& stack_trace) {
    stack_trace_ = stack_trace;
  }

  // Get the stack trace for when the node was instantiated.
  const std::shared_ptr<AbstractStackTrace>& GetStackTrace() const {
    return stack_trace_;
  }

  // Called after an attr has changed. Decides whether we need to update some
  // property of the node (stored in props_).
  void UpdateProperties();

 private:
  friend class Graph;
  Node();

  // Stack trace for the user code for node instantiation. Can be shared across
  // multiple nodes (e.g. when inlining).
  std::shared_ptr<AbstractStackTrace> stack_trace_;

  // Releases memory from props_, in addition to restoring *this to its
  // uninitialized state.
  void Clear();

  // Make a copy of the Node's props_ if props_ is shared with
  // other nodes. This must be called before mutating properties,
  // e.g. in AddAttr.
  void MaybeCopyOnWrite();

  AttrValue* AddAttrHelper(const std::string& name);

  // A set of mutually exclusive classes for different kinds of nodes,
  // class_ is initialized in the Node::Initialize routine based on the
  // node's type_string().
  enum NodeClass {
    NC_UNINITIALIZED,
    NC_SWITCH,
    NC_MERGE,
    NC_ENTER,
    NC_EXIT,
    NC_NEXT_ITERATION,
    NC_LOOP_COND,
    NC_CONTROL_TRIGGER,
    NC_SEND,
    NC_HOST_SEND,
    NC_RECV,
    NC_HOST_RECV,
    NC_CONSTANT,
    NC_VARIABLE,
    NC_IDENTITY,
    NC_GET_SESSION_HANDLE,
    NC_GET_SESSION_TENSOR,
    NC_DELETE_SESSION_TENSOR,
    NC_METADATA,
    NC_SCOPED_ALLOCATOR,
    NC_COLLECTIVE,
    NC_FAKE_PARAM,
    NC_PARTITIONED_CALL,
    NC_FUNCTION_OP,
    NC_SYMBOLIC_GRADIENT,
    NC_IF,
    NC_WHILE,
    NC_CASE,
    NC_ARG,
    NC_RETVAL,
    NC_OTHER  // Not a special kind of node
  };

  void Initialize(int id, int cost_id, std::shared_ptr<NodeProperties> props,
                  NodeClass node_class);

  // Maps an op type string to the NodeClass used for the helpers above.
  static NodeClass GetNodeClassForOp(const std::string& ts);

  int id_;       // -1 until Initialize() is called
  int cost_id_;  // -1 if there is no corresponding cost accounting node
  NodeClass class_;

  EdgeSet in_edges_;
  EdgeSet out_edges_;

  // NOTE(skyewm): inheriting from core::RefCounted may have a slight
  // performance benefit over using shared_ptr, at the cost of manual ref
  // counting
  std::shared_ptr<NodeProperties> props_;

  // Index within Graph::device_names_ of the name of device assigned
  // to perform this computation.
  int assigned_device_name_index_;

  // A back-pointer to the Graph that owns this node. Currently, this exists
  // solely to allow Node::[set_]assigned_device_name() to work. However, if all
  // callers of Node::[set_]assigned_device_name() are modified to use the
  // equivalent methods defined directly on Graph, then we can remove this
  // field and reclaim that memory.
  Graph* graph_;

  // Set if this is an exit node of a while loop with an associated
  // WhileContext. Otherwise null. (This is only set for exit nodes because
  // they're the first nodes of a loop encountered while creating the gradient
  // graph. Exit nodes that are part of while loop gradient graphs will not have
  // this set.)
  WhileContext* while_ctx_;

  TF_DISALLOW_COPY_AND_ASSIGN(Node);
};
361
// Stores debug information associated with the Node.
struct NodeDebugInfo {
  // Name of the node this debug info describes.
  const std::string name;
  // Names of the original nodes this node was derived from; corresponds to
  // the 'original_node_names' field of the NodeDef's experimental debug info
  // (see Node::set_original_node_names).
  std::vector<string> original_node_names;

  // Extracts the debug info from an existing Node or NodeDef, or builds it
  // from the raw components of a NodeDef's experimental debug info.
  NodeDebugInfo(const Node& n);
  NodeDebugInfo(const NodeDef& ndef);
  NodeDebugInfo(StringPiece node_name, bool has_experimental_debug_info,
                const NodeDef_ExperimentalDebugInfo& experimental_debug_info);
};
372
373 // Represents an input of a node, i.e., the `index`-th input to `node`.
374 struct InputTensor {
375 Node* node;
376 int index;
377
InputTensorInputTensor378 InputTensor(Node* n, int i) : node(n), index(i) {}
InputTensorInputTensor379 InputTensor() : node(nullptr), index(0) {}
380
381 // Returns true if this InputTensor is identical to 'other'. Nodes are
382 // compared using pointer equality.
383 bool operator==(const InputTensor& other) const;
384
385 // A hash function for InputTensors. Nodes are hashed based on their pointer
386 // value.
387 struct Hash {
388 uint64 operator()(InputTensor const& s) const;
389 };
390 };
391
392 // Represents an output of a node, i.e., the `index`-th output of `node`. Note
393 // that a single `OutputTensor` can correspond to multiple `Edge`s if the output
394 // is consumed by multiple destination nodes.
395 struct OutputTensor {
396 Node* node;
397 int index;
398
OutputTensorOutputTensor399 OutputTensor(Node* n, int i) : node(n), index(i) {}
OutputTensorOutputTensor400 OutputTensor() : node(nullptr), index(0) {}
401
402 // Returns true if this OutputTensor is identical to 'other'. Nodes are
403 // compared using pointer equality.
404 bool operator==(const OutputTensor& other) const;
405
406 // A hash function for OutputTensors. Nodes are hashed based on their pointer
407 // value.
408 struct Hash {
409 uint64 operator()(OutputTensor const& s) const;
410 };
411 };
412
class Edge {
 public:
  // Source and destination nodes of this edge.
  Node* src() const { return src_; }
  Node* dst() const { return dst_; }
  // Id of this edge within the owning Graph.
  int id() const { return id_; }

  // Return the index of the source output that produces the data
  // carried by this edge. The special value kControlSlot is used
  // for control dependencies.
  int src_output() const { return src_output_; }

  // Return the index of the destination input that consumes the data
  // carried by this edge. The special value kControlSlot is used
  // for control dependencies.
  int dst_input() const { return dst_input_; }

  // Return true iff this is an edge that indicates a control-flow
  // (as opposed to a data-flow) dependency.
  bool IsControlEdge() const;

  // Returns a human-readable description of this edge for debugging.
  std::string DebugString() const;

 private:
  // Private: edges are constructed and have their fields filled in by the
  // friend class Graph.
  Edge() {}

  friend class EdgeSetTest;
  friend class Graph;
  Node* src_;
  Node* dst_;
  int id_;
  int src_output_;
  int dst_input_;
};
446
447 // Allows for iteration of the edges of a Graph, by iterating the underlying
448 // Graph.edges_ vector while skipping over null entries.
449 class GraphEdgesIterable {
450 private:
451 const std::vector<Edge*>& edges_;
452
453 public:
GraphEdgesIterable(const std::vector<Edge * > & edges)454 explicit GraphEdgesIterable(const std::vector<Edge*>& edges)
455 : edges_(edges) {}
456
457 typedef Edge* value_type;
458
459 class const_iterator {
460 private:
461 // The underlying iterator.
462 std::vector<value_type>::const_iterator iter_;
463
464 // The end of the underlying iterator.
465 std::vector<value_type>::const_iterator end_;
466
467 // Advances iter_ until it reaches a non-null item, or reaches the end.
apply_filter()468 void apply_filter() {
469 while (iter_ != end_ && *iter_ == nullptr) {
470 ++iter_;
471 }
472 }
473
474 public:
const_iterator(std::vector<value_type>::const_iterator iter,std::vector<value_type>::const_iterator end)475 const_iterator(std::vector<value_type>::const_iterator iter,
476 std::vector<value_type>::const_iterator end)
477 : iter_(iter), end_(end) {
478 apply_filter();
479 }
480
481 bool operator==(const const_iterator& other) const {
482 return iter_ == other.iter_;
483 }
484
485 bool operator!=(const const_iterator& other) const {
486 return iter_ != other.iter_;
487 }
488
489 // This is the prefix increment operator (++x), which is the operator
490 // used by C++ range iteration (for (x : y) ...). We intentionally do not
491 // provide a postfix increment operator.
492 const_iterator& operator++() {
493 ++iter_;
494 apply_filter();
495 return *this;
496 }
497
498 value_type operator*() { return *iter_; }
499 };
500
begin()501 const_iterator begin() {
502 return const_iterator(edges_.begin(), edges_.end());
503 }
end()504 const_iterator end() { return const_iterator(edges_.end(), edges_.end()); }
505 };
506
507 // Thread compatible but not thread safe.
508 class Graph {
509 public:
510 // Constructs a graph with a single SOURCE (always id kSourceId) and a
511 // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK.
512 //
513 // The graph can hold ops found in the registry. `ops`s lifetime must be at
514 // least that of the constructed graph's.
515 explicit Graph(const OpRegistryInterface* ops);
516
517 // Constructs a graph with a single SOURCE (always id kSourceId) and a
518 // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK.
519 //
520 // The graph can hold ops found in `flib_def`. Unlike the constructor taking
521 // an OpRegistryInterface, this constructor copies the function definitions in
522 // `flib_def` so its lifetime may be shorter than that of the graph's. The
523 // OpRegistryInterface backing `flib_def` must still have the lifetime of the
524 // graph though.
525 explicit Graph(const FunctionLibraryDefinition& flib_def);
526
527 ~Graph();
528
529 // Clone the current graph into a new one.
530 std::unique_ptr<Graph> Clone();
531
532 static const int kControlSlot;
533
534 // The GraphDef version range of this graph (see graph.proto).
535 const VersionDef& versions() const;
536 void set_versions(const VersionDef& versions);
537
538 // Adds a new node to this graph, and returns it. Infers the Op and
539 // input/output types for the node. *this owns the returned instance.
540 // Returns nullptr and sets *status on error.
541 Node* AddNode(NodeDef node_def, Status* status);
542
543 // Copies *node, which may belong to another graph, to a new node,
544 // which is returned. Does not copy any edges. *this owns the
545 // returned instance.
546 Node* CopyNode(const Node* node);
547
548 // Removes a node from this graph, including all edges from or to it.
549 // *node should not be accessed after calling this function.
550 // REQUIRES: node->IsOp()
551 void RemoveNode(Node* node);
552
553 void Copy(const Graph& src);
554
555 // Adds an edge that connects the xth output of `source` to the yth input of
556 // `dest` and returns it. Does not update dest's NodeDef.
557 const Edge* AddEdge(Node* source, int x, Node* dest, int y);
558
559 // Adds a control edge (no data flows along this edge) that connects `source`
560 // to `dest`. If `dest`s NodeDef is missing the corresponding control input,
561 // adds the control input.
562 //
563 // If such a control edge already exists and `allow_duplicates` is false, no
564 // edge is added and the function returns nullptr. Otherwise the edge is
565 // unconditionally created and returned. The NodeDef is not updated if
566 // `allow_duplicates` is true.
// TODO(skyewm): allow_duplicates is needed only by
// graph_partition.cc. Figure out if we can do away with it.
569 const Edge* AddControlEdge(Node* source, Node* dest,
570 bool allow_duplicates = false);
571
572 // Removes edge from the graph. Does not update the destination node's
573 // NodeDef.
574 // REQUIRES: The edge must exist.
575 void RemoveEdge(const Edge* edge);
576
577 // Removes control edge `edge` from the graph. Note that this also updates
578 // the corresponding NodeDef to reflect the change.
579 // REQUIRES: The control edge must exist.
580 void RemoveControlEdge(const Edge* e);
581
582 // Updates the input to a node. The existing edge to `dst` is removed and an
583 // edge from `new_src` to `dst` is created. The NodeDef associated with `dst`
584 // is also updated.
585 Status UpdateEdge(Node* new_src, int new_src_index, Node* dst, int dst_index);
586
587 // Like AddEdge but updates dst's NodeDef. Used to add an input edge to a
588 // "While" op during gradient construction, see AddInputWhileHack in
589 // python_api.h for more details.
590 Status AddWhileInputHack(Node* new_src, int new_src_index, Node* dst);
591
592 // Adds the function and gradient definitions in `fdef_lib` to this graph's op
593 // registry. Ignores duplicate functions, and returns a bad status if an
594 // imported function differs from an existing function or op with the same
595 // name.
596 Status AddFunctionLibrary(const FunctionDefLibrary& fdef_lib);
597
598 // The number of live nodes in the graph.
599 //
600 // Because nodes can be removed from the graph, num_nodes() is often
601 // smaller than num_node_ids(). If one needs to create an array of
602 // nodes indexed by node ids, num_node_ids() should be used as the
603 // array's size.
num_nodes()604 int num_nodes() const { return num_nodes_; }
605
606 // The number of live nodes in the graph, excluding the Source and Sink nodes.
num_op_nodes()607 int num_op_nodes() const {
608 DCHECK_GE(num_nodes_, 2);
609 return num_nodes_ - 2;
610 }
611
612 // The number of live edges in the graph.
613 //
614 // Because edges can be removed from the graph, num_edges() is often
615 // smaller than num_edge_ids(). If one needs to create an array of
616 // edges indexed by edge ids, num_edge_ids() should be used as the
617 // array's size.
num_edges()618 int num_edges() const { return num_edges_; }
619
620 // Serialize the nodes starting at `from_node_id` to a GraphDef.
621 void ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const;
622
623 // Serialize to a GraphDef.
624 void ToGraphDef(GraphDef* graph_def) const;
625
626 // This version can be called from debugger to inspect the graph content.
627 // Use the previous version outside debug context for efficiency reasons.
628 //
629 // Note: We do not expose a DebugString() API, since GraphDef.DebugString() is
630 // not defined in some TensorFlow builds.
631 GraphDef ToGraphDefDebug() const;
632
633 // Generate new node name with the specified prefix that is unique
634 // across this graph.
635 std::string NewName(StringPiece prefix);
636
637 // Access to the list of all nodes. Example usage:
638 // for (Node* node : graph.nodes()) { ... }
639 gtl::iterator_range<NodeIter> nodes() const;
640
641 // Access to the list of all nodes, excluding the Source and Sink nodes.
642 gtl::iterator_range<NodeIter> op_nodes() const;
643
644 // Returns one more than the maximum id assigned to any node.
num_node_ids()645 int num_node_ids() const { return nodes_.size(); }
646
647 // Returns the node associated with an id, or nullptr if no node
648 // with that id (the node with that id was removed and the id has
649 // not yet been re-used). *this owns the returned instance.
650 // REQUIRES: 0 <= id < num_node_ids().
FindNodeId(int id)651 Node* FindNodeId(int id) const { return nodes_[id]; }
652
653 // Returns one more than the maximum id assigned to any edge.
num_edge_ids()654 int num_edge_ids() const { return edges_.size(); }
655
656 // Returns the Edge associated with an id, or nullptr if no edge
657 // with that id (the edge with that id was removed and the id has
658 // not yet been re-used). *this owns the returned instance.
659 // REQUIRES: 0 <= id < num_edge_ids().
FindEdgeId(int id)660 const Edge* FindEdgeId(int id) const { return edges_[id]; }
661
662 // Access to the set of all edges. Example usage:
663 // for (const Edge* e : graph.edges()) { ... }
edges()664 GraphEdgesIterable edges() const { return GraphEdgesIterable(edges_); }
665
666 // The pre-defined nodes.
667 enum { kSourceId = 0, kSinkId = 1 };
source_node()668 Node* source_node() const { return FindNodeId(kSourceId); }
sink_node()669 Node* sink_node() const { return FindNodeId(kSinkId); }
670
op_registry()671 const OpRegistryInterface* op_registry() const { return &ops_; }
flib_def()672 const FunctionLibraryDefinition& flib_def() const { return ops_; }
673
674 // TODO(mdan): This is only used by control_flow_deps_o_chains. Remove?
mutable_flib_def()675 FunctionLibraryDefinition* mutable_flib_def() { return &ops_; }
676
CheckDeviceNameIndex(int index)677 void CheckDeviceNameIndex(int index) {
678 DCHECK_GE(index, 0);
679 DCHECK_LT(index, static_cast<int>(device_names_.size()));
680 }
681
682 int InternDeviceName(const std::string& device_name);
683
get_assigned_device_name(const Node & node)684 const std::string& get_assigned_device_name(const Node& node) const {
685 return device_names_[node.assigned_device_name_index()];
686 }
687
set_assigned_device_name_index(Node * node,int device_name_index)688 void set_assigned_device_name_index(Node* node, int device_name_index) {
689 CheckDeviceNameIndex(device_name_index);
690 node->assigned_device_name_index_ = device_name_index;
691 }
692
set_assigned_device_name(Node * node,const std::string & device_name)693 void set_assigned_device_name(Node* node, const std::string& device_name) {
694 node->assigned_device_name_index_ = InternDeviceName(device_name);
695 }
696
697 // Returns OK if `node` is non-null and belongs to this graph
698 Status IsValidNode(const Node* node) const;
699
700 // Returns OK if IsValidNode(`node`) and `idx` is a valid output. Does not
701 // accept control outputs.
702 Status IsValidOutputTensor(const Node* node, int idx) const;
703
704 // Returns OK if IsValidNode(`node`) and `idx` a valid input. Does not accept
705 // control inputs.
706 Status IsValidInputTensor(const Node* node, int idx) const;
707
708 // Create and return a new WhileContext owned by this graph. This is called
709 // when a new while loop is created. `frame_name` must be unique among
710 // WhileContexts in this graph.
711 Status AddWhileContext(StringPiece frame_name, std::vector<Node*> enter_nodes,
712 std::vector<Node*> exit_nodes,
713 OutputTensor cond_output,
714 std::vector<OutputTensor> body_inputs,
715 std::vector<OutputTensor> body_outputs,
716 WhileContext** result);
717
718 // Builds a node name to node pointer index for all nodes in the graph.
719 std::unordered_map<string, Node*> BuildNodeNameIndex() const;
720
GetConstArgIndicesCache()721 absl::optional<std::vector<bool>>& GetConstArgIndicesCache() const {
722 return const_arg_indices_cache_;
723 }
724
// TODO(kkb): Add to the constructor when it becomes manageable.
726 // Sets the graph construction context.
SetConstructionContext(ConstructionContext construction_context)727 void SetConstructionContext(ConstructionContext construction_context) {
728 construction_context_ = construction_context;
729 }
730
731 // TODO(kkb): Rename to `GetConstructionContext` once we're comfortable
732 // making this stable and make it available widely.
// Returns the graph construction context. It's `kNotTracked` if not set.
GetConstructionContextInternal()734 ConstructionContext GetConstructionContextInternal() const {
735 return construction_context_;
736 }
737
738 void SetNodeType(StringPiece name, const FullTypeDef& type);
739
740 void NodeType(StringPiece name, FullTypeDef** result);
741
742 // TODO(josh11b): uint64 hash() const;
743
744 private:
745 // If cost_node is non-null, then cost accounting (in CostModel)
746 // will be associated with that node rather than the new one being
747 // created.
748 //
749 // Ownership of the returned Node is not transferred to caller.
750 Node* AllocateNode(std::shared_ptr<NodeProperties> props,
751 const Node* cost_node, Node::NodeClass node_class);
752 void ReleaseNode(Node* node);
753 // Insert edge in free_edges_ for possible reuse.
754 void RecycleEdge(const Edge* edge);
755 // Registry of all known ops, including functions.
756 FunctionLibraryDefinition ops_;
757
758 // GraphDef versions
759 const std::unique_ptr<VersionDef> versions_;
760
761 // Allocator which will give us good locality.
762 core::Arena arena_;
763
764 // Map from node ids to allocated nodes. nodes_[id] may be nullptr if
765 // the node with that id was removed from the graph.
766 std::vector<Node*> nodes_;
767
768 // Types table.
769 // TODO(mdan): Do not store these here. Instead, keep in a GraphDef field.
770 std::unordered_set<TypeRef, TypeHasher> types_;
771
772 // Experimental.
// Map from node names to their outputs' FullType. Typically, the values
774 // in this map are identical to those in types_, but that is not enforced or
775 // guaranteed.
776 //
777 // The full type specification combines a Tensor's dtype, tensor_shape,
778 // variant_val, etc. into a unified representation.
779 // This definition may only contain concrete types (for example,
780 // Tensor<TypeVar<'T'>> is not a valid node type).
781 //
782 // Presently, FullType duplicates any information found in `dtype`. When set,
783 // it is always consistent with `dtype`. Eventually, `dtype` will be merged
784 // with FullType.
785 //
786 // For example, if a TensorProto has `dtype=DT_INT32`, then
787 // `full_type=FT_TENSOR[FT_INT32]`.
788 // TODO(mdan): Do not store these here. Instead, keep in a GraphDef field.
789 std::unordered_map<string, TypeRef> node_name_to_out_type_;
790
791 // Number of nodes alive.
792 int64 num_nodes_ = 0;
793
794 // Map from edge ids to allocated edges. edges_[id] may be nullptr if
795 // the edge with that id was removed from the graph.
796 std::vector<Edge*> edges_;
797
798 // The number of entries in edges_ that are not nullptr.
799 int num_edges_ = 0;
800
801 // Allocated but free nodes and edges.
802 std::vector<Node*> free_nodes_;
803 std::vector<Edge*> free_edges_;
804
805 // For generating unique names.
806 int name_counter_ = 0;
807
808 // In most graphs, the number of unique values used for the
809 // Node::assigned_device_name() property is quite small. If the graph is
810 // large, then this duplication of values can consume a significant amount of
811 // memory. Instead, we represent the same information using an interning
812 // table, which consists of a vector of unique strings (device_names_), as
813 // well a map (device_names_map_) from unique strings to indices within the
814 // unique string table.
815 //
816 // The InternDeviceName() method handles adding a new entry into the table,
817 // or locating the index of an existing entry.
818 //
819 // The fact that Node::assigned_device_name() is implemented using an
820 // interning table is intentionally public. This allows algorithms that
821 // frequently access this field to do so efficiently, especially for the case
822 // where the assigned_device_name of one Node is copied directly from that
823 // of another Node.
824
825 // A table of the unique assigned device names. Indices do NOT correspond
826 // to node IDs. Index 0 is always the empty string.
827 std::vector<string> device_names_;
828
829 // Maps unique device names to indices within device_names_[i].
830 std::unordered_map<string, int> device_names_map_;
831
832 // All the while contexts owned by this graph, keyed by frame name,
833 // corresponding to all the while loops contained in this graph (including
834 // nested loops). The stored contexts are usually accessed via
835 // AddWhileContext() or Node::while_ctx(), but this manages the lifetime.
836 std::map<string, WhileContext> while_ctxs_;
837
838 // Cache of the indices of the arguments which need to be constant for the XLA
839 // compilation.
840 mutable absl::optional<std::vector<bool>> const_arg_indices_cache_;
841
842 // Indicates the context that this Graph instance is constructed.
843 ConstructionContext construction_context_ = ConstructionContext::kNotTracked;
844
845 TF_DISALLOW_COPY_AND_ASSIGN(Graph);
846 };
847
848 // TODO(josh11b): We may want to support keeping an index on various
849 // node/edge attributes in a graph, particularly node names.
850
// Helper routines

// Thin free-function wrappers around the corresponding Node accessors,
// convenient for use as predicates (e.g. with standard algorithms).
inline bool IsSource(const Node* node) { return node->IsSource(); }
inline bool IsSink(const Node* node) { return node->IsSink(); }
inline bool IsSwitch(const Node* node) { return node->IsSwitch(); }
inline bool IsMerge(const Node* node) { return node->IsMerge(); }
inline bool IsEnter(const Node* node) { return node->IsEnter(); }
inline bool IsExit(const Node* node) { return node->IsExit(); }
inline bool IsNextIteration(const Node* n) { return n->IsNextIteration(); }
inline bool IsLoopCond(const Node* node) { return node->IsLoopCond(); }
inline bool IsControlTrigger(const Node* n) { return n->IsControlTrigger(); }
inline bool IsSend(const Node* node) { return node->IsSend(); }
inline bool IsRecv(const Node* node) { return node->IsRecv(); }
inline bool IsHostSend(const Node* node) { return node->IsHostSend(); }
inline bool IsHostRecv(const Node* node) { return node->IsHostRecv(); }

// True for Nodes that mediate the transfer of values between processes.
inline bool IsTransferNode(const Node* n) { return IsSend(n) || IsRecv(n); }

inline bool IsConstant(const Node* node) { return node->IsConstant(); }
inline bool IsVariable(const Node* node) { return node->IsVariable(); }
inline bool IsIdentity(const Node* node) { return node->IsIdentity(); }

// Returns true iff 'n' is a control flow node.
inline bool IsControlFlow(const Node* n) { return n->IsControlFlow(); }

// Returns true if the node only depends on its input's metadata
// (shape). Specifically, returns true for "Size", "Shape" and "Rank" ops.
inline bool IsMetadata(const Node* n) { return n->IsMetadata(); }

inline bool IsScopedAllocator(const Node* n) { return n->IsScopedAllocator(); }

// True for nodes that pass their (data) input through on host memory:
// identity-like and control-flow nodes.
inline bool IsHostMemoryPreserving(const Node* node) {
  return IsIdentity(node) || IsControlFlow(node);
}

inline bool IsDistributedCommunication(const Node* n) {
  return n->IsDistributedCommunication();
}
890
891 // NOTE: We declare Reference type of NodeIter and NeighborIter as Node* (see
892 // https://en.cppreference.com/w/cpp/iterator/iterator).
893
894 // Iterator for stepping through the nodes of a graph.
895 class NodeIter
896 : public std::iterator<std::forward_iterator_tag, Node, std::ptrdiff_t,
897 /*Pointer*/ Node*, /*Reference*/ Node*> {
898 public:
899 NodeIter(const Graph* graph, int id);
900 bool operator==(const NodeIter& rhs) const;
901 bool operator!=(const NodeIter& rhs) const;
902 void operator++();
903 reference operator*() const;
904 pointer operator->() const;
905
906 private:
907 // Invariant: id_ == graph_->num_node_ids() || graph_->FindId(id_) != nullptr
908 const Graph* graph_;
909 int id_;
910 };
911
912 // Iterator for stepping through the neighbors of a node.
913 class NeighborIter
914 : public std::iterator<std::forward_iterator_tag, Node, std::ptrdiff_t,
915 /*Pointer*/ Node*, /*Reference*/ Node*> {
916 public:
917 NeighborIter(EdgeSet::const_iterator iter, bool incoming);
918 bool operator==(const NeighborIter& rhs) const;
919 bool operator!=(const NeighborIter& rhs) const;
920 void operator++();
921 reference operator*() const;
922 pointer operator->() const;
923
924 private:
925 EdgeSet::const_iterator iter_;
926 bool incoming_;
927 };
928
929 // IMPLEMENTATION DETAILS, PLEASE IGNORE
930
// Constructs an iterator over `graph` positioned at node id `id`.
inline NodeIter::NodeIter(const Graph* graph, int id)
    : graph_(graph), id_(id) {}

// Two NodeIters are equal iff they hold the same id; comparing iterators
// from different graphs is a bug (DCHECK'd).
inline bool NodeIter::operator==(const NodeIter& rhs) const {
  DCHECK(graph_ == rhs.graph_);
  return id_ == rhs.id_;
}

inline bool NodeIter::operator!=(const NodeIter& rhs) const {
  return !(*this == rhs);
}
942
943 inline void NodeIter::operator++() {
944 while (1) {
945 DCHECK_LE(id_, graph_->num_node_ids());
946 ++id_;
947 if (id_ >= graph_->num_node_ids() || graph_->FindNodeId(id_) != nullptr) {
948 return;
949 }
950 }
951 }
952
// Dereferencing yields Node* rather than Node&; see the NOTE above NodeIter.
inline Node* NodeIter::operator*() const { return graph_->FindNodeId(id_); }

inline Node* NodeIter::operator->() const { return graph_->FindNodeId(id_); }
956
// Constructs an iterator over the edges in `iter`'s edge set; `incoming`
// selects whether dereferencing yields each edge's source or destination.
inline NeighborIter::NeighborIter(EdgeSet::const_iterator iter, bool incoming)
    : iter_(iter), incoming_(incoming) {}

inline bool NeighborIter::operator==(const NeighborIter& rhs) const {
  return iter_ == rhs.iter_ && incoming_ == rhs.incoming_;
}

inline bool NeighborIter::operator!=(const NeighborIter& rhs) const {
  return !(*this == rhs);
}

// Advances the underlying edge iterator.
inline void NeighborIter::operator++() { ++iter_; }
969
970 inline Node* NeighborIter::operator*() const {
971 const Edge* e = *iter_;
972 return incoming_ ? e->src() : e->dst();
973 }
974
975 inline Node* NeighborIter::operator->() const {
976 const Edge* e = *iter_;
977 return incoming_ ? e->src() : e->dst();
978 }
979
// True iff this edge is a control edge, marked by src_output_ being the
// special kControlSlot value.
inline bool Edge::IsControlEdge() const {
  // Note that if either src_output_ or dst_input_ is kControlSlot,
  // so is the other one (AddEdge checks this).
  return src_output_ == Graph::kControlSlot;
}
985
// Range over every node in the graph, including the source and sink nodes.
inline gtl::iterator_range<NodeIter> Graph::nodes() const {
  // Note that NodeId 0 is always valid since we don't let the source
  // node be removed from the graph.
  return gtl::make_range(NodeIter(this, 0), NodeIter(this, num_node_ids()));
}
991
op_nodes()992 inline gtl::iterator_range<NodeIter> Graph::op_nodes() const {
993 // Note that NodeId 0 is always valid since we don't let the source
994 // node be removed from the graph.
995 //
996 // The current implementation of Graph maintains the invariant that the
997 // first two nodes are the source and sink nodes, and all other nodes are op
998 // nodes. This method (op_nodes()) relies on this invariant.
999 NodeIter begin(this, 0);
1000 NodeIter end(this, num_node_ids());
1001 if (begin != end) {
1002 ++begin;
1003 }
1004 if (begin != end) {
1005 ++begin;
1006 }
1007 return gtl::make_range(begin, end);
1008 }
1009
// Stores the device-name index directly. The index is validated against the
// owning graph's interning table via CheckDeviceNameIndex().
inline void Node::set_assigned_device_name_index(int index) {
  graph_->CheckDeviceNameIndex(index);
  assigned_device_name_index_ = index;
}

// Delegates to the owning Graph, which interns `device_name` and stores only
// an index on this node (see the interning-table comment in Graph).
inline void Node::set_assigned_device_name(const std::string& device_name) {
  graph_->set_assigned_device_name(this, device_name);
}

// Resolves this node's interned device-name index back to the string.
inline const std::string& Node::assigned_device_name() const {
  return graph_->get_assigned_device_name(*this);
}
1022
1023 } // namespace tensorflow
1024
1025 #endif // TENSORFLOW_CORE_GRAPH_GRAPH_H_
1026