1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_VIEW_H_ 17 #define TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_VIEW_H_ 18 19 #include <memory> 20 #include <vector> 21 22 #include "tensorflow/core/framework/allocator.h" 23 #include "tensorflow/core/framework/types.h" 24 #include "tensorflow/core/lib/core/status.h" 25 #include "tensorflow/core/lib/gtl/array_slice.h" 26 #include "tensorflow/core/platform/logging.h" 27 #include "tensorflow/core/platform/macros.h" 28 #include "tensorflow/core/platform/types.h" 29 30 namespace tensorflow { 31 32 class Device; 33 class Graph; 34 class Node; 35 class OpKernel; 36 class Tensor; 37 38 // Represents a single data edge in a `NodeItem`. 39 struct EdgeInfo { 40 // The node ID of the destination in the containing `GraphView`. 41 int dst_id; 42 // The index of the output that produces values on this edge. 43 int output_slot : 31; 44 // true if this is the last info for output_slot in the EdgeInfo list. 45 bool is_last : 1; 46 // The index of the input that consumes values on this edge. 47 int input_slot; 48 }; 49 50 // Represents a single control edge in a `NodeItem`. 51 struct ControlEdgeInfo { 52 // The node ID of the destination in the containing `GraphView`. 53 int dst_id; 54 }; 55 56 // Compact structure representing a graph node and its associated kernel. 57 // 58 // Each NodeItem is an element of exactly one GraphView. 59 struct NodeItem { 60 // The index of this node's item in its GraphView. 61 int node_id = -1; 62 63 // Cached attributes of this node for fast lookup. 64 bool kernel_is_async : 1; // True iff kernel->AsAsync() != nullptr 65 bool is_merge : 1; // True iff IsMerge(node) 66 bool is_enter : 1; // True iff IsEnter(node) 67 bool is_constant_enter : 1; // True iff IsEnter(node) and 68 // node->GetAttr("is_constant") == true. 69 bool is_exit : 1; // True iff IsExit(node) 70 bool is_control_trigger : 1; // True iff IsControlTrigger(node) 71 bool is_source : 1; // True iff IsSource(node) 72 // True iff IsEnter(node) || IsExit(node) || IsNextIteration(node) 73 bool is_enter_exit_or_next_iter : 1; 74 bool is_transfer_node : 1; // True iff IsTransferNode(node) 75 bool is_initialization_op : 1; // True iff IsInitializationOp(node) 76 bool is_recv_or_switch : 1; // True iff IsRecv(node) || IsSwitch(node) 77 bool is_next_iteration : 1; // True iff IsNextIteration(node) 78 bool is_noop : 1; // True iff item->kernel->type_string_view() == "NoOp") 79 bool 80 is_any_consumer_merge_or_control_trigger : 1; // True iff the destination 81 // of any output edge is a 82 // merge or control trigger 83 // node. 84 bool is_any_input_ref_typed : 1; // True iff any IsRefType(dt) for dt in this 85 // node's input types. 86 87 // The kernel for this node. 88 OpKernel* kernel = nullptr; 89 90 // If the kernel is a Const op, this containts points to the constant tensor. 91 const Tensor* const_tensor = nullptr; 92 93 // Cached values of node->num_inputs() and node->num_outputs(), to 94 // avoid levels of indirection. 95 int num_inputs; 96 int num_outputs; 97 98 // ExecutorImpl::tensors_[input_start] is the 1st positional input 99 // for this node. 100 int input_start = 0; 101 102 // Number of output edges, excluding control edges. 103 int32 num_output_edges; 104 105 // Number of output control edges. 106 int32 num_output_control_edges; 107 108 // If non-null, contains an array of num_outputs bools, where the ith bool 109 // is true if and only if the ith output is consumed by another node. 110 std::unique_ptr<bool[]> outputs_required; 111 mutable_output_edgesNodeItem112 gtl::MutableArraySlice<EdgeInfo> mutable_output_edges() { 113 return gtl::MutableArraySlice<EdgeInfo>(output_edge_base(), 114 num_output_edges); 115 } 116 output_edgesNodeItem117 gtl::ArraySlice<EdgeInfo> output_edges() const { 118 return gtl::ArraySlice<EdgeInfo>(output_edge_base(), num_output_edges); 119 } 120 output_control_edgesNodeItem121 gtl::ArraySlice<ControlEdgeInfo> output_control_edges() const { 122 return gtl::ArraySlice<const ControlEdgeInfo>(output_control_edge_base(), 123 num_output_control_edges); 124 } 125 input_typeNodeItem126 DataType input_type(int i) const { 127 DCHECK_LT(i, num_inputs); 128 return static_cast<DataType>(input_type_base()[i]); 129 } output_typeNodeItem130 DataType output_type(int i) const { 131 DCHECK_LT(i, num_outputs); 132 return static_cast<DataType>(output_type_base()[i]); 133 } 134 135 // Return array of per-output allocator attributes. output_attrsNodeItem136 const AllocatorAttributes* output_attrs() const { return output_attr_base(); } 137 138 // Return array of expected input index from which each output should 139 // be forwarded: 140 // kNeverForward (-2) for DO NOT FORWARD (must allocate). 141 // kNoReservation (-1) for no expected forwarding. 142 // 0... for forward from that input. forward_fromNodeItem143 const int* forward_from() const { return forward_from_base(); } 144 145 string DebugString() const; 146 147 private: 148 friend class GraphView; 149 NodeItemNodeItem150 NodeItem() {} 151 152 // Variable length section starts immediately after *this 153 // (uint8 is enough for DataType). 154 // EdgeInfo out_edges[num_output_edges]; 155 // ControlEdgeInfo out_control_edges[num_output_control_edges]; 156 // AllocatorAttributes output_attr[num_outputs]; 157 // int forward_from[num_outputs]; 158 // uint8 input_type[num_inputs]; 159 // uint8 output_type[num_outputs]; 160 161 // Return pointer to variable length section. varNodeItem162 char* var() const { 163 return const_cast<char*>(reinterpret_cast<const char*>(this) + 164 sizeof(NodeItem)); 165 } 166 output_edge_baseNodeItem167 EdgeInfo* output_edge_base() const { 168 return reinterpret_cast<EdgeInfo*>(var()); 169 } 170 output_control_edge_baseNodeItem171 ControlEdgeInfo* output_control_edge_base() const { 172 return reinterpret_cast<ControlEdgeInfo*>(var() + sizeof(EdgeInfo) * 173 num_output_edges); 174 } 175 output_attr_baseNodeItem176 AllocatorAttributes* output_attr_base() const { 177 return reinterpret_cast<AllocatorAttributes*>( 178 var() + sizeof(EdgeInfo) * num_output_edges + 179 sizeof(ControlEdgeInfo) * num_output_control_edges); 180 } forward_from_baseNodeItem181 int* forward_from_base() const { 182 return reinterpret_cast<int*>(var() + sizeof(EdgeInfo) * num_output_edges + 183 sizeof(ControlEdgeInfo) * 184 num_output_control_edges + 185 sizeof(AllocatorAttributes) * num_outputs); 186 } input_type_baseNodeItem187 uint8* input_type_base() const { 188 return reinterpret_cast<uint8*>( 189 var() + sizeof(EdgeInfo) * num_output_edges + 190 sizeof(ControlEdgeInfo) * num_output_control_edges + 191 sizeof(AllocatorAttributes) * num_outputs + sizeof(int) * num_outputs); 192 } output_type_baseNodeItem193 uint8* output_type_base() const { 194 return reinterpret_cast<uint8*>( 195 var() + sizeof(EdgeInfo) * num_output_edges + 196 sizeof(ControlEdgeInfo) * num_output_control_edges + 197 sizeof(AllocatorAttributes) * num_outputs + sizeof(int) * num_outputs + 198 sizeof(uint8) * num_inputs); 199 } 200 201 TF_DISALLOW_COPY_AND_ASSIGN(NodeItem); 202 }; 203 204 // Immutable view of a Graph organized for efficient execution. 205 // 206 // TODO(b/152651962): Add independent unit tests for this class. 207 class GraphView { 208 public: GraphView()209 GraphView() : space_(nullptr) {} 210 ~GraphView(); 211 212 Status Initialize(const Graph* g); 213 Status SetAllocAttrs(const Graph* g, const Device* device); 214 void SetScopedAllocatorAttrs(const std::vector<const Node*>& sa_nodes); 215 216 // Returns a mutable pointer to the `NodeItem` with the given `id` if it 217 // exists in the graph, or `nullptr` if it does not. node(int32 id)218 NodeItem* node(int32 id) const { 219 DCHECK_GE(id, 0); 220 DCHECK_LT(id, num_nodes_); 221 uint32 offset = node_offsets_[id]; 222 return ((offset == kuint32max) 223 ? nullptr 224 : reinterpret_cast<NodeItem*>(space_ + node_offsets_[id])); 225 } 226 227 // Returns the `NodeItem` with the given `id`. 228 // 229 // REQUIRES: `id` must be the ID of a valid node in the graph. node_ref(int32 id)230 const NodeItem& node_ref(int32 id) const { 231 DCHECK_GE(id, 0); 232 DCHECK_LT(id, num_nodes_); 233 uint32 offset = node_offsets_[id]; 234 DCHECK_NE(offset, kuint32max); 235 return *reinterpret_cast<NodeItem*>(space_ + node_offsets_[id]); 236 } 237 num_nodes()238 int32 num_nodes() const { return num_nodes_; } 239 240 private: 241 char* InitializeNode(char* ptr, const Node* n); 242 size_t NodeItemBytes(const Node* n); 243 244 int32 num_nodes_ = 0; 245 uint32* node_offsets_ = nullptr; // array of size "num_nodes_" 246 // node_offsets_[id] holds the byte offset for node w/ "id" in space_ 247 248 char* space_; // NodeItem objects are allocated here 249 250 TF_DISALLOW_COPY_AND_ASSIGN(GraphView); 251 }; 252 253 } // namespace tensorflow 254 255 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_VIEW_H_ 256