1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // A Graph describes a set of computations that are to be
17 // performed, as well as the dependencies between those
18 // computations. The basic model is a DAG (directed acyclic graph) with
19 // * internal nodes representing computational operations to be performed;
20 // * edges represent dependencies, indicating the target may only be
21 // executed once the source has completed; and
22 // * predefined "source" (start) and "sink" (finish) nodes -- the source
23 // should be the only node that doesn't depend on anything, and the sink
24 // should be the only node that nothing depends on.
25 //
26 // Note: Node ids are intended to be relatively dense in the
27 // 0..max_id range, but there may be gaps since ids won't be reused.
28 //
29 // Note: Some dependencies between operations are due to one operation
30 // consuming the output of another. In fact operations can produce
31 // multiple outputs and consume multiple inputs, and some
32 // optimizations will care about which specific outputs are connected
33 // to which specific inputs. We therefore represent data dependency
34 // between output O of layer A and input I of layer B using
35 // "input index" and "output index" labels per edge.
36
37 #ifndef TENSORFLOW_CORE_GRAPH_GRAPH_H_
38 #define TENSORFLOW_CORE_GRAPH_GRAPH_H_
39
40 #include <functional>
41 #include <memory>
42 #include <string>
43 #include <vector>
44
45 #include "absl/types/optional.h"
46 #include "tensorflow/core/framework/full_type.pb.h"
47 #include "tensorflow/core/framework/function.h"
48 #include "tensorflow/core/framework/node_def.pb.h"
49 #include "tensorflow/core/framework/node_def_util.h"
50 #include "tensorflow/core/framework/op.h"
51 #include "tensorflow/core/framework/types.h"
52 #include "tensorflow/core/graph/edgeset.h"
53 #include "tensorflow/core/lib/core/arena.h"
54 #include "tensorflow/core/lib/core/refcount.h"
55 #include "tensorflow/core/lib/core/status.h"
56 #include "tensorflow/core/lib/gtl/iterator_range.h"
57 #include "tensorflow/core/platform/logging.h"
58 #include "tensorflow/core/platform/macros.h"
59 #include "tensorflow/core/platform/stringpiece.h"
60 #include "tensorflow/core/platform/types.h"
61
62 namespace tensorflow {
63
64 class Edge;
65 class EdgeSetTest;
66 class Graph;
67 class GraphDef;
68 class Node;
69 struct OutputTensor;
70 class VersionDef;
71 class WhileContext;
72
73 class NeighborIter; // Declared below
74 class NodeIter; // Declared below
75
// Records where a Graph instance originated. The value is stored on the graph
// via Graph::SetConstructionContext(); kNotTracked means the origin was not
// recorded.
enum class ConstructionContext : int {
  kNotTracked = 0,     // Not tracked.
  kDirectSession = 1,  // From `tensorflow::DirectSession`, TF1 session API.
  kEagerRuntime = 2,   // Registered from TF2 eager runtime.
};
82
// A single operation (or the special Source/Sink node) in a Graph.
//
// Nodes cannot be constructed directly: the constructor is private and Graph
// is a friend. Graph::AddNode() / Graph::CopyNode() create nodes and the Graph
// owns them.
class Node {
 public:
  std::string DebugString() const;
  // Id of this node within its Graph. Ids 0 and 1 are the Graph's pre-defined
  // Source and Sink nodes (see IsSource() / IsSink()).
  int id() const { return id_; }
  // Id used for cost accounting; -1 if there is no corresponding cost
  // accounting node.
  int cost_id() const { return cost_id_; }
  const std::string& name() const;
  void set_name(std::string name);
  const std::string& type_string() const;

  // def() provides the NodeDef the user supplied, but the specifics
  // of this Node may have changed due to placement, optimization, etc.
  // In particular:
  // * def().name() will match name();
  // * def().op() will match type_string() and op_def().name();
  // * def().input() is not reliable, use "in_edges()" below instead;
  // * def().device() is the "user's requested device" and may not match
  //   the actual assigned device, see assigned_device_name() below;
  // * def().attr() is authoritative.
  // TODO(irving): Replace with NodeInfo.
  const NodeDef& def() const;
  const OpDef& op_def() const;

  // TODO(mdan): This is only used by control_flow_deps_to_chains. Remove?
  NodeDef* mutable_def();

  // Input and output types of this node's data (non-control) slots.
  int32 num_inputs() const;
  DataType input_type(int32_t i) const;
  const DataTypeVector& input_types() const;

  int32 num_outputs() const;
  DataType output_type(int32_t o) const;
  const DataTypeVector& output_types() const;

  // The device requested by the user. For the actual assigned device,
  // use assigned_device_name() below.
  const std::string& requested_device() const;

  // This changes the user requested device but not necessarily the device
  // on which the operation will run.
  void set_requested_device(const std::string& device);

  // This gives the device the runtime has assigned this node to. If
  // you want the device the user requested, use requested_device() (or
  // def().device()) instead.
  // TODO(josh11b): Validate that the assigned_device, if not empty:
  // fully specifies a device, and satisfies def().device().
  // TODO(josh11b): Move assigned_device_name outside of Node into a
  // NodeId->DeviceName map.
  const std::string& assigned_device_name() const;
  void set_assigned_device_name(const std::string& device_name);
  // Returns true iff the interned device-name index is non-zero. Index 0 of
  // the owning Graph's device-name table is always the empty string.
  bool has_assigned_device_name() const {
    return assigned_device_name_index_ > 0;
  }
  // Index of the assigned device name within the owning Graph's interned
  // device-name table.
  int assigned_device_name_index() const { return assigned_device_name_index_; }
  void set_assigned_device_name_index(int index);

  // Sets 'original_node_names' field of this node's DebugInfo proto to
  // 'names'. set_original_func_names() is, by its name, the analogous setter
  // for 'original_func_names' (confirm against the definition).
  void set_original_node_names(const std::vector<string>& names);
  void set_original_func_names(const std::vector<string>& names);

  // Read only access to attributes
  AttrSlice attrs() const;

  // Inputs requested by the NodeDef. For the actual inputs, use in_edges.
  const protobuf::RepeatedPtrField<string>& requested_inputs() const;

  // Get the neighboring nodes via edges either in or out of this node. This
  // includes control edges.
  gtl::iterator_range<NeighborIter> in_nodes() const;
  gtl::iterator_range<NeighborIter> out_nodes() const;
  // The incident edges themselves (data and control edges).
  const EdgeSet& in_edges() const { return in_edges_; }
  const EdgeSet& out_edges() const { return out_edges_; }

  // Node type helpers, based purely on the node id.
  bool IsSource() const { return id() == 0; }
  bool IsSink() const { return id() == 1; }
  // Anything other than the special Source & Sink nodes.
  bool IsOp() const { return id() > 1; }

  // Node class helpers. Each one tests class_, which Initialize() sets from
  // the node's type_string().
  bool IsSwitch() const { return class_ == NC_SWITCH; }
  bool IsMerge() const { return class_ == NC_MERGE; }
  bool IsEnter() const { return class_ == NC_ENTER; }
  bool IsExit() const { return class_ == NC_EXIT; }
  bool IsNextIteration() const { return class_ == NC_NEXT_ITERATION; }
  bool IsLoopCond() const { return class_ == NC_LOOP_COND; }
  bool IsControlTrigger() const { return class_ == NC_CONTROL_TRIGGER; }
  // Host sends/recvs count as sends/recvs too.
  bool IsSend() const { return class_ == NC_SEND || class_ == NC_HOST_SEND; }
  bool IsRecv() const { return class_ == NC_RECV || class_ == NC_HOST_RECV; }
  bool IsConstant() const { return class_ == NC_CONSTANT; }
  bool IsVariable() const { return class_ == NC_VARIABLE; }
  bool IsIdentity() const { return class_ == NC_IDENTITY; }
  bool IsGetSessionHandle() const { return class_ == NC_GET_SESSION_HANDLE; }
  bool IsGetSessionTensor() const { return class_ == NC_GET_SESSION_TENSOR; }
  bool IsDeleteSessionTensor() const {
    return class_ == NC_DELETE_SESSION_TENSOR;
  }
  // True for Switch, Merge, Enter, Exit and NextIteration nodes. LoopCond is
  // deliberately not part of this set.
  bool IsControlFlow() const {
    return (class_ != NC_OTHER) &&  // Fast path
           (IsSwitch() || IsMerge() || IsEnter() || IsExit() ||
            IsNextIteration());
  }
  bool IsHostSend() const { return class_ == NC_HOST_SEND; }
  bool IsHostRecv() const { return class_ == NC_HOST_RECV; }
  bool IsScopedAllocator() const { return class_ == NC_SCOPED_ALLOCATOR; }
  bool IsCollective() const { return class_ == NC_COLLECTIVE; }

  bool IsMetadata() const { return class_ == NC_METADATA; }
  bool IsFakeParam() const { return class_ == NC_FAKE_PARAM; }
  bool IsPartitionedCall() const { return class_ == NC_PARTITIONED_CALL; }

  // Returns true if this node is any kind of function call node.
  //
  // NOTE: "function call nodes" include partitioned call ops, symbolic gradient
  // ops, and ops whose type_string is the name of a function ("function ops").
  bool IsFunctionCall() const {
    return class_ == NC_PARTITIONED_CALL || class_ == NC_FUNCTION_OP ||
           class_ == NC_SYMBOLIC_GRADIENT;
  }

  bool IsIfNode() const { return class_ == NC_IF; }
  bool IsWhileNode() const { return class_ == NC_WHILE; }
  bool IsCaseNode() const { return class_ == NC_CASE; }
  // Is this node a function input
  bool IsArg() const { return class_ == NC_ARG; }
  // Is this node a function output
  bool IsRetval() const { return class_ == NC_RETVAL; }

  // Forwards to the OpDef's is_distributed_communication flag (not class_).
  bool IsDistributedCommunication() const {
    return op_def().is_distributed_communication();
  }

  // Writes `val` into the attribute slot returned by AddAttrHelper(name), then
  // calls UpdateProperties() so cached properties reflect the change.
  template <typename T>
  void AddAttr(const std::string& name, const T& val) {
    SetAttrValue(val, AddAttrHelper(name));
    UpdateProperties();
  }

  // Overload for an rvalue vector of strings: moves the value instead of
  // copying it.
  void AddAttr(const std::string& name, std::vector<string>&& val) {
    MoveAttrValue(std::move(val), AddAttrHelper(name));
    UpdateProperties();
  }

  void ClearAttr(const std::string& name);

  // Returns into '*e' the edge connecting to the 'idx' input of this Node.
  Status input_edge(int idx, const Edge** e) const;

  // Returns into '*edges' the input data edges of this Node, indexed by input
  // number. Does not return control edges.
  Status input_edges(std::vector<const Edge*>* edges) const;

  // Returns into '*n' the node that has an output connected to the
  // 'idx' input of this Node.
  Status input_node(int idx, const Node** n) const;
  Status input_node(int idx, Node** n) const;

  // Returns into '*t' the idx-th input tensor of this node, represented as the
  // output tensor of input_node(idx).
  Status input_tensor(int idx, OutputTensor* t) const;

  // The WhileContext associated with this node, or null (see while_ctx_).
  WhileContext* while_ctx() const { return while_ctx_; }
  // Associates a WhileContext with this node. Only valid on Exit nodes, and
  // only once: both preconditions are DCHECKed.
  void set_while_ctx(WhileContext* while_ctx) {
    DCHECK(IsExit());
    DCHECK(while_ctx_ == nullptr);
    while_ctx_ = while_ctx;
  }

  // Returns the node's properties. The object may be shared with other nodes
  // (see MaybeCopyOnWrite()).
  std::shared_ptr<NodeProperties> properties() const { return props_; }

  // Sets the stack trace for the node. Assumes that getting and setting the
  // stack trace for a given node will not race.
  void SetStackTrace(const std::shared_ptr<AbstractStackTrace>& stack_trace) {
    stack_trace_ = stack_trace;
  }

  // Get the stack trace for when the node was instantiated.
  const std::shared_ptr<AbstractStackTrace>& GetStackTrace() const {
    return stack_trace_;
  }

  // Called after an attr has changed. Decides whether we need to update some
  // property of the node (stored in props_).
  void UpdateProperties();

  // Erases type information from the node.
  void ClearTypeInfo();

  // Called after an incident non-control edge has changed. Does nothing if not
  // all input edges are defined.
  void RunForwardTypeInference();

 private:
  // TODO(mdan): Drop this.
  friend class Graph;
  Node();

  // Stack trace for the user code for node instantiation. Can be shared across
  // multiple nodes (e.g. when inlining).
  std::shared_ptr<AbstractStackTrace> stack_trace_;

  // Releases memory from props_, in addition to restoring *this to its
  // uninitialized state.
  void Clear();

  // Make a copy of the Node's props_ if props_ is shared with
  // other nodes. This must be called before mutating properties,
  // e.g. in AddAttr.
  void MaybeCopyOnWrite();

  // Returns the AttrValue slot for attribute `name`, which AddAttr() fills in.
  AttrValue* AddAttrHelper(const std::string& name);

  // A set of mutually exclusive classes for different kinds of nodes,
  // class_ is initialized in the Node::Initialize routine based on the
  // node's type_string().
  enum NodeClass {
    NC_UNINITIALIZED,
    NC_SWITCH,
    NC_MERGE,
    NC_ENTER,
    NC_EXIT,
    NC_NEXT_ITERATION,
    NC_LOOP_COND,
    NC_CONTROL_TRIGGER,
    NC_SEND,
    NC_HOST_SEND,
    NC_RECV,
    NC_HOST_RECV,
    NC_CONSTANT,
    NC_VARIABLE,
    NC_IDENTITY,
    NC_GET_SESSION_HANDLE,
    NC_GET_SESSION_TENSOR,
    NC_DELETE_SESSION_TENSOR,
    NC_METADATA,
    NC_SCOPED_ALLOCATOR,
    NC_COLLECTIVE,
    NC_FAKE_PARAM,
    NC_PARTITIONED_CALL,
    NC_FUNCTION_OP,
    NC_SYMBOLIC_GRADIENT,
    NC_IF,
    NC_WHILE,
    NC_CASE,
    NC_ARG,
    NC_RETVAL,
    NC_OTHER  // Not a special kind of node
  };

  void Initialize(int id, int cost_id, std::shared_ptr<NodeProperties> props,
                  NodeClass node_class);

  // Maps an op type string to its NodeClass.
  static NodeClass GetNodeClassForOp(const std::string& ts);

  int id_;       // -1 until Initialize() is called
  int cost_id_;  // -1 if there is no corresponding cost accounting node
  NodeClass class_;

  EdgeSet in_edges_;
  EdgeSet out_edges_;

  // NOTE(skyewm): inheriting from core::RefCounted may have a slight
  // performance benefit over using shared_ptr, at the cost of manual ref
  // counting
  std::shared_ptr<NodeProperties> props_;

  // Index within Graph::device_names_ of the name of device assigned
  // to perform this computation.
  int assigned_device_name_index_;

  // A back-pointer to the Graph that owns this node. Currently, this exists
  // solely to allow Node::[set_]assigned_device_name() to work. However, if all
  // callers of Node::[set_]assigned_device_name() are modified to use the
  // equivalent methods defined directly on Graph, then we can remove this
  // field and reclaim that memory.
  Graph* graph_;

  // Set if this is an exit node of a while loop with an associated
  // WhileContext. Otherwise null. (This is only set for exit nodes because
  // they're the first nodes of a loop encountered while creating the gradient
  // graph. Exit nodes that are part of while loop gradient graphs will not have
  // this set.)
  WhileContext* while_ctx_;

  TF_DISALLOW_COPY_AND_ASSIGN(Node);
};
370
// Stores debug information associated with the Node: its name plus the
// 'original_node_names' / 'original_func_names' lists from the node's
// DebugInfo (compare Node::set_original_node_names()).
struct NodeDebugInfo {
  // `name` is const, so NodeDebugInfo can be copy-constructed but not
  // copy-assigned.
  const std::string name;
  std::vector<string> original_node_names;
  std::vector<string> original_func_names;

  // These constructors are not `explicit`: a Node or NodeDef converts
  // implicitly to a NodeDebugInfo, and callers may rely on that conversion.
  NodeDebugInfo(const Node& n);
  NodeDebugInfo(const NodeDef& ndef);
  // Builds the info from a node name and its experimental debug info.
  // Presumably `experimental_debug_info` is only consulted when
  // `has_experimental_debug_info` is true (confirm in the definition).
  NodeDebugInfo(StringPiece node_name, bool has_experimental_debug_info,
                const NodeDef_ExperimentalDebugInfo& experimental_debug_info);
};
382
383 // Represents an input of a node, i.e., the `index`-th input to `node`.
384 struct InputTensor {
385 Node* node;
386 int index;
387
InputTensorInputTensor388 InputTensor(Node* n, int i) : node(n), index(i) {}
InputTensorInputTensor389 InputTensor() : node(nullptr), index(0) {}
390
391 // Returns true if this InputTensor is identical to 'other'. Nodes are
392 // compared using pointer equality.
393 bool operator==(const InputTensor& other) const;
394
395 // A hash function for InputTensors. Nodes are hashed based on their pointer
396 // value.
397 struct Hash {
398 uint64 operator()(InputTensor const& s) const;
399 };
400 };
401
402 // Represents an output of a node, i.e., the `index`-th output of `node`. Note
403 // that a single `OutputTensor` can correspond to multiple `Edge`s if the output
404 // is consumed by multiple destination nodes.
405 struct OutputTensor {
406 Node* node;
407 int index;
408
OutputTensorOutputTensor409 OutputTensor(Node* n, int i) : node(n), index(i) {}
OutputTensorOutputTensor410 OutputTensor() : node(nullptr), index(0) {}
411
412 // Returns true if this OutputTensor is identical to 'other'. Nodes are
413 // compared using pointer equality.
414 bool operator==(const OutputTensor& other) const;
415
416 // A hash function for OutputTensors. Nodes are hashed based on their pointer
417 // value.
418 struct Hash {
419 uint64 operator()(OutputTensor const& s) const;
420 };
421 };
422
423 class Edge {
424 public:
src()425 Node* src() const { return src_; }
dst()426 Node* dst() const { return dst_; }
id()427 int id() const { return id_; }
428
429 // Return the index of the source output that produces the data
430 // carried by this edge. The special value kControlSlot is used
431 // for control dependencies.
src_output()432 int src_output() const { return src_output_; }
433
434 // Return the index of the destination input that consumes the data
435 // carried by this edge. The special value kControlSlot is used
436 // for control dependencies.
dst_input()437 int dst_input() const { return dst_input_; }
438
439 // Return true iff this is an edge that indicates a control-flow
440 // (as opposed to a data-flow) dependency.
441 bool IsControlEdge() const;
442
443 std::string DebugString() const;
444
445 private:
Edge()446 Edge() {}
447
448 friend class EdgeSetTest;
449 friend class Graph;
450 Node* src_;
451 Node* dst_;
452 int id_;
453 int src_output_;
454 int dst_input_;
455 };
456
457 // Allows for iteration of the edges of a Graph, by iterating the underlying
458 // Graph.edges_ vector while skipping over null entries.
459 class GraphEdgesIterable {
460 private:
461 const std::vector<Edge*>& edges_;
462
463 public:
GraphEdgesIterable(const std::vector<Edge * > & edges)464 explicit GraphEdgesIterable(const std::vector<Edge*>& edges)
465 : edges_(edges) {}
466
467 typedef Edge* value_type;
468
469 class const_iterator {
470 private:
471 // The underlying iterator.
472 std::vector<value_type>::const_iterator iter_;
473
474 // The end of the underlying iterator.
475 std::vector<value_type>::const_iterator end_;
476
477 // Advances iter_ until it reaches a non-null item, or reaches the end.
apply_filter()478 void apply_filter() {
479 while (iter_ != end_ && *iter_ == nullptr) {
480 ++iter_;
481 }
482 }
483
484 public:
const_iterator(std::vector<value_type>::const_iterator iter,std::vector<value_type>::const_iterator end)485 const_iterator(std::vector<value_type>::const_iterator iter,
486 std::vector<value_type>::const_iterator end)
487 : iter_(iter), end_(end) {
488 apply_filter();
489 }
490
491 bool operator==(const const_iterator& other) const {
492 return iter_ == other.iter_;
493 }
494
495 bool operator!=(const const_iterator& other) const {
496 return iter_ != other.iter_;
497 }
498
499 // This is the prefix increment operator (++x), which is the operator
500 // used by C++ range iteration (for (x : y) ...). We intentionally do not
501 // provide a postfix increment operator.
502 const_iterator& operator++() {
503 ++iter_;
504 apply_filter();
505 return *this;
506 }
507
508 value_type operator*() { return *iter_; }
509 };
510
begin()511 const_iterator begin() {
512 return const_iterator(edges_.begin(), edges_.end());
513 }
end()514 const_iterator end() { return const_iterator(edges_.end(), edges_.end()); }
515 };
516
517 // Thread compatible but not thread safe.
518 class Graph {
519 public:
520 // Constructs a graph with a single SOURCE (always id kSourceId) and a
521 // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK.
522 //
523 // The graph can hold ops found in the registry. `ops`s lifetime must be at
524 // least that of the constructed graph's.
525 explicit Graph(const OpRegistryInterface* ops);
526
527 // Constructs a graph with a single SOURCE (always id kSourceId) and a
528 // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK.
529 //
530 // The graph can hold ops found in `flib_def`. Unlike the constructor taking
531 // an OpRegistryInterface, this constructor copies the function definitions in
532 // `flib_def` so its lifetime may be shorter than that of the graph's. The
533 // OpRegistryInterface backing `flib_def` must still have the lifetime of the
534 // graph though.
535 explicit Graph(const FunctionLibraryDefinition& flib_def);
536
537 ~Graph();
538
539 // Clone the current graph into a new one.
540 std::unique_ptr<Graph> Clone();
541
542 static const int kControlSlot;
543
544 // The GraphDef version range of this graph (see graph.proto).
545 const VersionDef& versions() const;
546 void set_versions(const VersionDef& versions);
547
548 // Adds a new node to this graph, and returns it. Infers the Op and
549 // input/output types for the node. *this owns the returned instance.
550 // Returns nullptr and sets *status on error.
551 Node* AddNode(NodeDef node_def, Status* status);
552
553 // Same as above, but using StatusOr. This method is always preferred.
554 StatusOr<Node*> AddNode(NodeDef node_def);
555
556 // Copies *node, which may belong to another graph, to a new node,
557 // which is returned. Does not copy any edges. *this owns the
558 // returned instance.
559 Node* CopyNode(const Node* node);
560
561 // Removes a node from this graph, including all edges from or to it.
562 // *node should not be accessed after calling this function.
563 // REQUIRES: node->IsOp()
564 void RemoveNode(Node* node);
565
566 void Copy(const Graph& src);
567
// Removes all nodes from this graph, including all edges from or to them.
// No Node* pointers into the Graph remain valid afterwards.
570 void Clear();
571
572 // Adds an edge that connects the xth output of `source` to the yth input of
573 // `dest` and returns it. Does not update dest's NodeDef.
574 const Edge* AddEdge(Node* source, int x, Node* dest, int y);
575
576 // Adds a control edge (no data flows along this edge) that connects `source`
577 // to `dest`. If `dest`s NodeDef is missing the corresponding control input,
578 // adds the control input.
579 //
580 // If such a control edge already exists and `allow_duplicates` is false, no
581 // edge is added and the function returns nullptr. Otherwise the edge is
582 // unconditionally created and returned. The NodeDef is not updated if
583 // `allow_duplicates` is true.
// TODO(skyewm): allow_duplicates is needed only by graph_partition.cc.
// Figure out if we can do away with it.
586 const Edge* AddControlEdge(Node* source, Node* dest,
587 bool allow_duplicates = false);
588
589 // Removes edge from the graph. Does not update the destination node's
590 // NodeDef.
591 // REQUIRES: The edge must exist.
592 void RemoveEdge(const Edge* edge);
593
594 // Removes control edge `edge` from the graph. Note that this also updates
595 // the corresponding NodeDef to reflect the change.
596 // REQUIRES: The control edge must exist.
597 void RemoveControlEdge(const Edge* e);
598
599 // Updates the input to a node. The existing edge to `dst` is removed and an
600 // edge from `new_src` to `dst` is created. The NodeDef associated with `dst`
601 // is also updated.
602 Status UpdateEdge(Node* new_src, int new_src_index, Node* dst, int dst_index);
603
604 // Like AddEdge but updates dst's NodeDef. Used to add an input edge to a
605 // "While" op during gradient construction, see AddInputWhileHack in
606 // python_api.h for more details.
607 Status AddWhileInputHack(Node* new_src, int new_src_index, Node* dst);
608
609 // Adds the function and gradient definitions in `fdef_lib` to this graph's op
610 // registry. Ignores duplicate functions, and returns a bad status if an
611 // imported function differs from an existing function or op with the same
612 // name.
613 Status AddFunctionLibrary(const FunctionDefLibrary& fdef_lib);
614
615 // The number of live nodes in the graph.
616 //
617 // Because nodes can be removed from the graph, num_nodes() is often
618 // smaller than num_node_ids(). If one needs to create an array of
619 // nodes indexed by node ids, num_node_ids() should be used as the
620 // array's size.
num_nodes()621 int num_nodes() const { return num_nodes_; }
622
623 // The number of live nodes in the graph, excluding the Source and Sink nodes.
num_op_nodes()624 int num_op_nodes() const {
625 DCHECK_GE(num_nodes_, 2);
626 return num_nodes_ - 2;
627 }
628
629 // The number of live edges in the graph.
630 //
631 // Because edges can be removed from the graph, num_edges() is often
632 // smaller than num_edge_ids(). If one needs to create an array of
633 // edges indexed by edge ids, num_edge_ids() should be used as the
634 // array's size.
num_edges()635 int num_edges() const { return num_edges_; }
636
637 // Serialize the nodes starting at `from_node_id` to a GraphDef.
638 void ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const;
639
640 // Serialize to a GraphDef.
641 void ToGraphDef(GraphDef* graph_def) const;
642
643 // This version can be called from debugger to inspect the graph content.
644 // Use the previous version outside debug context for efficiency reasons.
645 //
646 // Note: We do not expose a DebugString() API, since GraphDef.DebugString() is
647 // not defined in some TensorFlow builds.
648 GraphDef ToGraphDefDebug() const;
649
650 // Generate new node name with the specified prefix that is unique
651 // across this graph.
652 std::string NewName(StringPiece prefix);
653
654 // Access to the list of all nodes. Example usage:
655 // for (Node* node : graph.nodes()) { ... }
656 gtl::iterator_range<NodeIter> nodes() const;
657
658 // Access to the list of all nodes, excluding the Source and Sink nodes.
659 gtl::iterator_range<NodeIter> op_nodes() const;
660
661 // Returns one more than the maximum id assigned to any node.
num_node_ids()662 int num_node_ids() const { return nodes_.size(); }
663
664 // Returns the node associated with an id, or nullptr if no node
665 // with that id (the node with that id was removed and the id has
666 // not yet been re-used). *this owns the returned instance.
667 // REQUIRES: 0 <= id < num_node_ids().
FindNodeId(int id)668 Node* FindNodeId(int id) const { return nodes_[id]; }
669
670 // Returns one more than the maximum id assigned to any edge.
num_edge_ids()671 int num_edge_ids() const { return edges_.size(); }
672
673 // Returns the Edge associated with an id, or nullptr if no edge
674 // with that id (the edge with that id was removed and the id has
675 // not yet been re-used). *this owns the returned instance.
676 // REQUIRES: 0 <= id < num_edge_ids().
FindEdgeId(int id)677 const Edge* FindEdgeId(int id) const { return edges_[id]; }
678
679 // Access to the set of all edges. Example usage:
680 // for (const Edge* e : graph.edges()) { ... }
edges()681 GraphEdgesIterable edges() const { return GraphEdgesIterable(edges_); }
682
683 // The pre-defined nodes.
684 enum { kSourceId = 0, kSinkId = 1 };
source_node()685 Node* source_node() const { return FindNodeId(kSourceId); }
sink_node()686 Node* sink_node() const { return FindNodeId(kSinkId); }
687
op_registry()688 const OpRegistryInterface* op_registry() const { return &ops_; }
flib_def()689 const FunctionLibraryDefinition& flib_def() const { return ops_; }
690
691 // TODO(mdan): This is only used by control_flow_deps_o_chains. Remove?
mutable_flib_def()692 FunctionLibraryDefinition* mutable_flib_def() { return &ops_; }
693
CheckDeviceNameIndex(int index)694 void CheckDeviceNameIndex(int index) {
695 DCHECK_GE(index, 0);
696 DCHECK_LT(index, static_cast<int>(device_names_.size()));
697 }
698
699 int InternDeviceName(const std::string& device_name);
700
get_assigned_device_name(const Node & node)701 const std::string& get_assigned_device_name(const Node& node) const {
702 return device_names_[node.assigned_device_name_index()];
703 }
704
set_assigned_device_name_index(Node * node,int device_name_index)705 void set_assigned_device_name_index(Node* node, int device_name_index) {
706 CheckDeviceNameIndex(device_name_index);
707 node->assigned_device_name_index_ = device_name_index;
708 }
709
set_assigned_device_name(Node * node,const std::string & device_name)710 void set_assigned_device_name(Node* node, const std::string& device_name) {
711 node->assigned_device_name_index_ = InternDeviceName(device_name);
712 }
713
714 // Returns OK if `node` is non-null and belongs to this graph
715 Status IsValidNode(const Node* node) const;
716
717 // Returns OK if IsValidNode(`node`) and `idx` is a valid output. Does not
718 // accept control outputs.
719 Status IsValidOutputTensor(const Node* node, int idx) const;
720
721 // Returns OK if IsValidNode(`node`) and `idx` a valid input. Does not accept
722 // control inputs.
723 Status IsValidInputTensor(const Node* node, int idx) const;
724
725 // Create and return a new WhileContext owned by this graph. This is called
726 // when a new while loop is created. `frame_name` must be unique among
727 // WhileContexts in this graph.
728 Status AddWhileContext(StringPiece frame_name, std::vector<Node*> enter_nodes,
729 std::vector<Node*> exit_nodes,
730 OutputTensor cond_output,
731 std::vector<OutputTensor> body_inputs,
732 std::vector<OutputTensor> body_outputs,
733 WhileContext** result);
734
735 // Builds a node name to node pointer index for all nodes in the graph.
736 std::unordered_map<string, Node*> BuildNodeNameIndex() const;
737
GetConstArgIndicesCache()738 absl::optional<std::vector<bool>>& GetConstArgIndicesCache() const {
739 return const_arg_indices_cache_;
740 }
741
742 // TODO(kkb): Add to the constructor when it becomes managable.
743 // Sets the graph construction context.
SetConstructionContext(ConstructionContext construction_context)744 void SetConstructionContext(ConstructionContext construction_context) {
745 construction_context_ = construction_context;
746 }
747
748 // TODO(kkb): Rename to `GetConstructionContext` once we're comfortable
749 // making this stable and make it available widely.
750 // Returns the graph construction context. It's `kUnknown` if not set.
GetConstructionContextInternal()751 ConstructionContext GetConstructionContextInternal() const {
752 return construction_context_;
753 }
754
755 // TODO(josh11b): uint64 hash() const;
756
757 private:
758 // If cost_node is non-null, then cost accounting (in CostModel)
759 // will be associated with that node rather than the new one being
760 // created.
761 //
762 // Ownership of the returned Node is not transferred to caller.
763 Node* AllocateNode(std::shared_ptr<NodeProperties> props,
764 const Node* cost_node, Node::NodeClass node_class);
765 void ReleaseNode(Node* node);
766 // Insert edge in free_edges_ for possible reuse.
767 void RecycleEdge(const Edge* edge);
768 // Registry of all known ops, including functions.
769 FunctionLibraryDefinition ops_;
770
771 // GraphDef versions
772 const std::unique_ptr<VersionDef> versions_;
773
774 // Allocator which will give us good locality.
775 core::Arena arena_;
776
777 // Map from node ids to allocated nodes. nodes_[id] may be nullptr if
778 // the node with that id was removed from the graph.
779 std::vector<Node*> nodes_;
780
781 // Number of nodes alive.
782 int64_t num_nodes_ = 0;
783
784 // Map from edge ids to allocated edges. edges_[id] may be nullptr if
785 // the edge with that id was removed from the graph.
786 std::vector<Edge*> edges_;
787
788 // The number of entries in edges_ that are not nullptr.
789 int num_edges_ = 0;
790
791 // Allocated but free nodes and edges.
792 std::vector<Node*> free_nodes_;
793 std::vector<Edge*> free_edges_;
794
795 // For generating unique names.
796 int name_counter_ = 0;
797
798 // In most graphs, the number of unique values used for the
799 // Node::assigned_device_name() property is quite small. If the graph is
800 // large, then this duplication of values can consume a significant amount of
801 // memory. Instead, we represent the same information using an interning
802 // table, which consists of a vector of unique strings (device_names_), as
// well as a map (device_names_map_) from unique strings to indices within the
804 // unique string table.
805 //
806 // The InternDeviceName() method handles adding a new entry into the table,
807 // or locating the index of an existing entry.
808 //
809 // The fact that Node::assigned_device_name() is implemented using an
810 // interning table is intentionally public. This allows algorithms that
811 // frequently access this field to do so efficiently, especially for the case
812 // where the assigned_device_name of one Node is copied directly from that
813 // of another Node.
814
815 // A table of the unique assigned device names. Indices do NOT correspond
816 // to node IDs. Index 0 is always the empty string.
817 std::vector<string> device_names_;
818
// Maps unique device names to their indices within device_names_.
820 std::unordered_map<string, int> device_names_map_;
821
822 // All the while contexts owned by this graph, keyed by frame name,
823 // corresponding to all the while loops contained in this graph (including
824 // nested loops). The stored contexts are usually accessed via
825 // AddWhileContext() or Node::while_ctx(), but this manages the lifetime.
826 std::map<string, WhileContext> while_ctxs_;
827
828 // Cache of the indices of the arguments which need to be constant for the XLA
829 // compilation.
830 mutable absl::optional<std::vector<bool>> const_arg_indices_cache_;
831
// Indicates the context in which this Graph instance was constructed.
833 ConstructionContext construction_context_ = ConstructionContext::kNotTracked;
834
835 TF_DISALLOW_COPY_AND_ASSIGN(Graph);
836 };
837
838 // TODO(josh11b): We may want to support keeping an index on various
839 // node/edge attributes in a graph, particularly node names.
840
841 // Helper routines
842
IsSource(const Node * node)843 inline bool IsSource(const Node* node) { return node->IsSource(); }
IsSink(const Node * node)844 inline bool IsSink(const Node* node) { return node->IsSink(); }
IsSwitch(const Node * node)845 inline bool IsSwitch(const Node* node) { return node->IsSwitch(); }
IsMerge(const Node * node)846 inline bool IsMerge(const Node* node) { return node->IsMerge(); }
IsEnter(const Node * node)847 inline bool IsEnter(const Node* node) { return node->IsEnter(); }
IsExit(const Node * node)848 inline bool IsExit(const Node* node) { return node->IsExit(); }
IsNextIteration(const Node * n)849 inline bool IsNextIteration(const Node* n) { return n->IsNextIteration(); }
IsLoopCond(const Node * node)850 inline bool IsLoopCond(const Node* node) { return node->IsLoopCond(); }
IsControlTrigger(const Node * n)851 inline bool IsControlTrigger(const Node* n) { return n->IsControlTrigger(); }
IsSend(const Node * node)852 inline bool IsSend(const Node* node) { return node->IsSend(); }
IsRecv(const Node * node)853 inline bool IsRecv(const Node* node) { return node->IsRecv(); }
IsHostSend(const Node * node)854 inline bool IsHostSend(const Node* node) { return node->IsHostSend(); }
IsHostRecv(const Node * node)855 inline bool IsHostRecv(const Node* node) { return node->IsHostRecv(); }
856
857 // True for Nodes that mediate the transfer of values between processes.
IsTransferNode(const Node * n)858 inline bool IsTransferNode(const Node* n) { return IsSend(n) || IsRecv(n); }
859
IsConstant(const Node * node)860 inline bool IsConstant(const Node* node) { return node->IsConstant(); }
IsVariable(const Node * node)861 inline bool IsVariable(const Node* node) { return node->IsVariable(); }
IsIdentity(const Node * node)862 inline bool IsIdentity(const Node* node) { return node->IsIdentity(); }
863
864 // Returns true iff 'n' is a control flow node.
IsControlFlow(const Node * n)865 inline bool IsControlFlow(const Node* n) { return n->IsControlFlow(); }
866
867 // Returns true if the node only depends on its input's metadata
868 // (shape). Specifically, returns true for "Size", "Shape" and "Rank" ops.
IsMetadata(const Node * n)869 inline bool IsMetadata(const Node* n) { return n->IsMetadata(); }
870
IsScopedAllocator(const Node * n)871 inline bool IsScopedAllocator(const Node* n) { return n->IsScopedAllocator(); }
872
IsHostMemoryPreserving(const Node * node)873 inline bool IsHostMemoryPreserving(const Node* node) {
874 return IsIdentity(node) || IsControlFlow(node);
875 }
876
IsDistributedCommunication(const Node * n)877 inline bool IsDistributedCommunication(const Node* n) {
878 return n->IsDistributedCommunication();
879 }
880
881 // NOTE: We declare Reference type of NodeIter and NeighborIter as Node* (see
882 // https://en.cppreference.com/w/cpp/iterator/iterator).
883
884 // Iterator for stepping through the nodes of a graph.
885 class NodeIter
886 : public std::iterator<std::forward_iterator_tag, Node, std::ptrdiff_t,
887 /*Pointer*/ Node*, /*Reference*/ Node*> {
888 public:
889 NodeIter(const Graph* graph, int id);
890 bool operator==(const NodeIter& rhs) const;
891 bool operator!=(const NodeIter& rhs) const;
892 void operator++();
893 reference operator*() const;
894 pointer operator->() const;
895
896 private:
897 // Invariant: id_ == graph_->num_node_ids() || graph_->FindId(id_) != nullptr
898 const Graph* graph_;
899 int id_;
900 };
901
902 // Iterator for stepping through the neighbors of a node.
903 class NeighborIter
904 : public std::iterator<std::forward_iterator_tag, Node, std::ptrdiff_t,
905 /*Pointer*/ Node*, /*Reference*/ Node*> {
906 public:
907 NeighborIter(EdgeSet::const_iterator iter, bool incoming);
908 bool operator==(const NeighborIter& rhs) const;
909 bool operator!=(const NeighborIter& rhs) const;
910 void operator++();
911 reference operator*() const;
912 pointer operator->() const;
913
914 private:
915 EdgeSet::const_iterator iter_;
916 bool incoming_;
917 };
918
919 // IMPLEMENTATION DETAILS, PLEASE IGNORE
920
NodeIter(const Graph * graph,int id)921 inline NodeIter::NodeIter(const Graph* graph, int id)
922 : graph_(graph), id_(id) {}
923
924 inline bool NodeIter::operator==(const NodeIter& rhs) const {
925 DCHECK(graph_ == rhs.graph_);
926 return id_ == rhs.id_;
927 }
928
929 inline bool NodeIter::operator!=(const NodeIter& rhs) const {
930 return !(*this == rhs);
931 }
932
933 inline void NodeIter::operator++() {
934 while (1) {
935 DCHECK_LE(id_, graph_->num_node_ids());
936 ++id_;
937 if (id_ >= graph_->num_node_ids() || graph_->FindNodeId(id_) != nullptr) {
938 return;
939 }
940 }
941 }
942
943 inline Node* NodeIter::operator*() const { return graph_->FindNodeId(id_); }
944
945 inline Node* NodeIter::operator->() const { return graph_->FindNodeId(id_); }
946
NeighborIter(EdgeSet::const_iterator iter,bool incoming)947 inline NeighborIter::NeighborIter(EdgeSet::const_iterator iter, bool incoming)
948 : iter_(iter), incoming_(incoming) {}
949
950 inline bool NeighborIter::operator==(const NeighborIter& rhs) const {
951 return iter_ == rhs.iter_ && incoming_ == rhs.incoming_;
952 }
953
954 inline bool NeighborIter::operator!=(const NeighborIter& rhs) const {
955 return !(*this == rhs);
956 }
957
958 inline void NeighborIter::operator++() { ++iter_; }
959
960 inline Node* NeighborIter::operator*() const {
961 const Edge* e = *iter_;
962 return incoming_ ? e->src() : e->dst();
963 }
964
965 inline Node* NeighborIter::operator->() const {
966 const Edge* e = *iter_;
967 return incoming_ ? e->src() : e->dst();
968 }
969
IsControlEdge()970 inline bool Edge::IsControlEdge() const {
971 // Note that if either src_output_ or dst_input_ is kControlSlot,
972 // so is the other one (AddEdge checks this).
973 return src_output_ == Graph::kControlSlot;
974 }
975
nodes()976 inline gtl::iterator_range<NodeIter> Graph::nodes() const {
977 // Note that NodeId 0 is always valid since we don't let the source
978 // node be removed from the graph.
979 return gtl::make_range(NodeIter(this, 0), NodeIter(this, num_node_ids()));
980 }
981
op_nodes()982 inline gtl::iterator_range<NodeIter> Graph::op_nodes() const {
983 // Note that NodeId 0 is always valid since we don't let the source
984 // node be removed from the graph.
985 //
986 // The current implementation of Graph maintains the invariant that the
987 // first two nodes are the source and sink nodes, and all other nodes are op
988 // nodes. This method (op_nodes()) relies on this invariant.
989 NodeIter begin(this, 0);
990 NodeIter end(this, num_node_ids());
991 if (begin != end) {
992 ++begin;
993 }
994 if (begin != end) {
995 ++begin;
996 }
997 return gtl::make_range(begin, end);
998 }
999
set_assigned_device_name_index(int index)1000 inline void Node::set_assigned_device_name_index(int index) {
1001 graph_->CheckDeviceNameIndex(index);
1002 assigned_device_name_index_ = index;
1003 }
1004
set_assigned_device_name(const std::string & device_name)1005 inline void Node::set_assigned_device_name(const std::string& device_name) {
1006 graph_->set_assigned_device_name(this, device_name);
1007 }
1008
assigned_device_name()1009 inline const std::string& Node::assigned_device_name() const {
1010 return graph_->get_assigned_device_name(*this);
1011 }
1012
1013 } // namespace tensorflow
1014
1015 #endif // TENSORFLOW_CORE_GRAPH_GRAPH_H_
1016