1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_IMMUTABLE_EXECUTOR_STATE_H_ 16 #define TENSORFLOW_CORE_COMMON_RUNTIME_IMMUTABLE_EXECUTOR_STATE_H_ 17 18 #include <atomic> 19 #include <deque> 20 #include <memory> 21 #include <vector> 22 23 #include "absl/container/flat_hash_map.h" 24 #include "tensorflow/core/common_runtime/graph_view.h" 25 #include "tensorflow/core/common_runtime/local_executor_params.h" 26 #include "tensorflow/core/common_runtime/pending_counts.h" 27 #include "tensorflow/core/framework/tensor.h" 28 #include "tensorflow/core/lib/core/status.h" 29 #include "tensorflow/core/lib/gtl/flatmap.h" 30 #include "tensorflow/core/lib/gtl/flatset.h" 31 #include "tensorflow/core/platform/macros.h" 32 #include "tensorflow/core/platform/types.h" 33 34 namespace tensorflow { 35 36 class Graph; 37 38 // Represents the state of an executor (graph and control flow information) 39 // that is immutable throughout execution. 40 // 41 // TODO(b/152651962): Add independent unit tests for this class. 42 class ImmutableExecutorState { 43 public: 44 struct FrameInfo { FrameInfoFrameInfo45 explicit FrameInfo(string name) 46 : name(std::move(name)), 47 input_count(0), 48 total_inputs(0), 49 pending_counts(nullptr), 50 nodes(nullptr), 51 parallel_iterations(-1) {} 52 53 // The name of the frame. 54 string name; 55 56 // The total number of inputs to a frame. 57 int input_count; 58 59 // The total number of input tensors of a frame. 60 // == sum(nodes[*].num_inputs()) where nodes are the nodes in the frame. 61 int total_inputs; 62 63 // Used to determine the next place to allocate space in the 64 // pending_counts data structure we'll eventually construct 65 PendingCounts::Layout pending_counts_layout; 66 67 // Each frame has its own PendingCounts only for the nodes in the frame. 68 std::unique_ptr<PendingCounts> pending_counts; 69 70 // The nodes in a frame. Used only for debugging. 71 std::unique_ptr<std::vector<const NodeItem*>> nodes; 72 73 // The number of iterations of this frame that can execute concurrently. 74 int32 parallel_iterations; 75 }; 76 ImmutableExecutorState(const LocalExecutorParams & p)77 explicit ImmutableExecutorState(const LocalExecutorParams& p) 78 : params_(p), gview_() {} 79 ~ImmutableExecutorState(); 80 81 Status Initialize(const Graph& graph); 82 83 // Process all Nodes in the current graph, attempting to infer the 84 // memory allocation attributes to be used wherever they may allocate 85 // a tensor buffer. 86 Status SetAllocAttrs(); 87 params()88 const LocalExecutorParams& params() const { return params_; } graph_view()89 const GraphView& graph_view() const { return gview_; } pending_ids()90 const std::vector<PendingCounts::Handle>& pending_ids() const { 91 return pending_ids_; 92 } root_nodes()93 const std::vector<const NodeItem*>& root_nodes() const { return root_nodes_; } 94 get_root_frame_info()95 const FrameInfo& get_root_frame_info() const { return *root_frame_info_; } 96 get_enter_frame_info(const NodeItem & node_item)97 const FrameInfo& get_enter_frame_info(const NodeItem& node_item) const { 98 DCHECK(node_item.is_enter); 99 return *enter_frame_info_[node_item.node_id]; 100 } 101 requires_control_flow_support()102 bool requires_control_flow_support() const { return requires_control_flow_; } 103 104 // Copies the pending counts for nodes in this graph to the given array. 105 // 106 // This method provides a more efficient way of initializing 107 // `SimplePropagatorState` than individually accessing the pending counts from 108 // `get_root_frame_info().counts`. 109 // 110 // REQUIRES: `!requires_control_flow_support && len(dest) == 111 // graph_view().num_nodes()`. copy_pending_counts(std::atomic<int32> * dest)112 void copy_pending_counts(std::atomic<int32>* dest) const { 113 DCHECK(!requires_control_flow_); 114 memcpy(dest, atomic_pending_counts_.get(), 115 graph_view().num_nodes() * sizeof(std::atomic<int32>)); 116 std::atomic_thread_fence(std::memory_order_release); 117 } 118 119 private: 120 struct ControlFlowInfo { 121 gtl::FlatSet<string> unique_frame_names; 122 std::vector<string> frame_names; 123 }; 124 125 static Status BuildControlFlowInfo(const Graph* graph, 126 ControlFlowInfo* cf_info); 127 void InitializePending(const Graph* graph, const ControlFlowInfo& cf_info); 128 129 FrameInfo* EnsureFrameInfo(const string& fname); 130 131 // Owned. 132 LocalExecutorParams params_; 133 GraphView gview_; 134 bool requires_control_flow_; 135 std::vector<PendingCounts::Handle> pending_ids_; 136 137 // Root nodes (with no in edges) that should form the initial ready queue 138 std::vector<const NodeItem*> root_nodes_; 139 140 // Mapping from frame name to static information about the frame. 141 // TODO(yuanbyu): We could cache it along with the graph so to avoid 142 // the overhead of constructing it for each executor instance. 143 absl::flat_hash_map<absl::string_view, std::unique_ptr<FrameInfo>> 144 frame_info_; 145 const FrameInfo* root_frame_info_; // Not owned. 146 147 // If the graph contains any "Enter" or "RefEnter" nodes, this vector maps 148 // dense node IDs to the corresponding FrameInfo. 149 std::vector<FrameInfo*> enter_frame_info_; 150 151 // If `requires_control_flow_` is false, this points to an array of initial 152 // pending counts for the nodes in the graph, indexed by node ID. 153 std::unique_ptr<std::atomic<int32>[]> atomic_pending_counts_; 154 155 // Shallow copies of the constant tensors used in the graph. 156 std::vector<Tensor> const_tensors_; 157 158 TF_DISALLOW_COPY_AND_ASSIGN(ImmutableExecutorState); 159 }; 160 161 } // namespace tensorflow 162 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_IMMUTABLE_EXECUTOR_STATE_H_ 163