/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_PROPAGATOR_STATE_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_PROPAGATOR_STATE_H_

#include <atomic>
#include <cstddef>
#include <memory>
#include <vector>

#include "tensorflow/core/common_runtime/entry.h"
#include "tensorflow/core/common_runtime/immutable_executor_state.h"
#include "tensorflow/core/common_runtime/pending_counts.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Represents the ephemeral "edge state" associated with one invocation of
// `Executor::Run()`.
//
// NOTE: `SimplePropagatorState` does not support "v1-style" control flow,
// including "dead tensors", "Switch" and "Merge" nodes, and cycles in the
// graph. Use `PropagatorState` for graphs with those features.
// `SimplePropagatorState` *does* support "v2-style" or "functional" control
// flow.
41 // 42 // `SimplePropagatorState` is responsible for propagating values along dataflow 43 // edges in a TensorFlow graph and determining which nodes are runnable. The 44 // executor primarily updates `SimplePropagatorState` by calling 45 // `PropagateOutputs()` after processing a node, and `SimplePropagatorState` 46 // dispatches `TaggedNode`s by adding them to a `TaggedNodeSeq`. 47 class SimplePropagatorState { 48 public: 49 SimplePropagatorState(const ImmutableExecutorState& immutable_state, 50 int64 step_id, bool vlog); 51 ~SimplePropagatorState(); 52 53 // A `TaggedNode` corresponds to a single invocation of a node's kernel, 54 // and it is created when the kernel becomes runnable. 55 struct TaggedNode { 56 const NodeItem* node_item; 57 TaggedNodeTaggedNode58 explicit TaggedNode(const NodeItem* node_item) : node_item(node_item) {} 59 get_node_itemTaggedNode60 const NodeItem& get_node_item() const { return *node_item; } 61 get_is_deadTaggedNode62 bool get_is_dead() const { return false; } get_iter_numTaggedNode63 int64 get_iter_num() const { return 0; } 64 }; 65 66 // A drop-in replacement for std::deque<TaggedNode>. We typically don't 67 // have that many nodes in the ready queue, so we just use a vector and 68 // don't free up memory from the queue as we consume nodes. 69 // TODO(mrry): Extract this and share it with the version in 70 // `PropagatorState`. The correct constants might be different, since 71 // sizeof(TaggedNode) is smaller in this version. 
72 class TaggedNodeReadyQueue { 73 public: TaggedNodeReadyQueue()74 TaggedNodeReadyQueue() : front_index_(0) {} 75 push_back(const TaggedNode & node)76 void push_back(const TaggedNode& node) { ready_.push_back(node); } front()77 TaggedNode front() const { 78 DCHECK_LT(front_index_, ready_.size()); 79 return ready_[front_index_]; 80 } pop_front()81 void pop_front() { 82 DCHECK_LT(front_index_, ready_.size()); 83 front_index_++; 84 if ((front_index_ == ready_.size()) || (front_index_ > kSpillThreshold)) { 85 if (front_index_ == ready_.size()) { 86 ready_.clear(); 87 } else { 88 // Lots of unused entries at beginning of vector: move everything 89 // down to start of vector. 90 ready_.erase(ready_.begin(), ready_.begin() + front_index_); 91 } 92 front_index_ = 0; 93 } 94 } empty()95 bool empty() const { return ready_.empty(); } 96 97 private: 98 // TODO(b/152925936): Re-evaluate these constants with current usage 99 // patterns. 100 static constexpr int kSpillThreshold = 16384; 101 gtl::InlinedVector<TaggedNode, 16> ready_; 102 int front_index_; 103 }; 104 105 // TODO(b/152925936): Re-evaluate this constant with current usage patterns. 106 typedef gtl::InlinedVector<TaggedNode, 8> TaggedNodeSeq; 107 108 // Creates and adds a `TaggedNode` for each node in `roots` to `*ready`. 109 void ActivateRoots(gtl::ArraySlice<const NodeItem*> roots, 110 TaggedNodeSeq* ready); 111 112 // After processing the outputs, propagates the outputs to their dsts. 113 // Contents of *outputs are left in an indeterminate state after 114 // returning from this method. 115 void PropagateOutputs(const TaggedNode& tagged_node, EntryVector* outputs, 116 TaggedNodeSeq* ready); 117 118 // Returns an array of `Entry` objects corresponding to the inputs of 119 // `tagged_node`. 
GetInputTensors(const TaggedNode & tagged_node)120 Entry* GetInputTensors(const TaggedNode& tagged_node) { 121 #if defined(THREAD_SANITIZER) || defined(DEBUG) 122 // NOTE: This read of `pending_[...]` works around a limitation in TSAN. 123 // To avoid false positive data race reports, we need to perform an atomic 124 // object access that will establish the happens-before relation between 125 // the write to input_tensors_ in `PropagateOutputs()` and the read in 126 // `PrepareInputs()`. 127 CHECK_EQ(pending_[tagged_node.node_item->node_id], 0); 128 #endif // defined(THREAD_SANITIZER) || defined(DEBUG) 129 return input_tensors_.data() + tagged_node.node_item->input_start; 130 } 131 GetFrameAndIter(const TaggedNode & tagged_node)132 FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const { 133 return {0, 0}; 134 } 135 136 // Provide debugging output of the state of the executor. 137 void DumpState(); 138 139 // For debugging/logging only. MaybeMarkStarted(const TaggedNode & tagged_node)140 void MaybeMarkStarted(const TaggedNode& tagged_node) { 141 // TODO(misard) Replace with a finer-grain enabling flag once we add better 142 // optional debugging support. 143 if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) { 144 mutex_lock l(mu_); 145 (*active_)[tagged_node.node_item->node_id] = true; 146 } 147 } MaybeMarkCompleted(const TaggedNode & tagged_node)148 void MaybeMarkCompleted(const TaggedNode& tagged_node) { 149 // TODO(misard) Replace with a finer-grain enabling flag once we add better 150 // optional debugging support. 
151 if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) { 152 mutex_lock l(mu_); 153 (*active_)[tagged_node.node_item->node_id] = false; 154 } 155 } 156 157 private: 158 SimplePropagatorState(const ImmutableExecutorState& immutable_state_, 159 int64 step_id, 160 const ImmutableExecutorState::FrameInfo& finfo, 161 bool vlog); 162 163 const ImmutableExecutorState& immutable_state_; 164 const int64 step_id_; 165 const bool vlog_; 166 167 // The i-th node's j-th input is stored at 168 // `input_tensors[impl_->nodes[i].input_start + j]`. 169 // 170 // NOTE: No need to protect input_tensors[i] by any locks because it 171 // is resized once. Each element of input_tensors is written once by the 172 // source node of an edge and is cleared by the destination of the same 173 // edge. The destination node always runs after the source node, so there 174 // is never concurrent access to the same entry. 175 std::vector<Entry> input_tensors_; 176 177 std::unique_ptr<std::atomic<int32>[]> pending_; 178 179 // If `vlog_` is true, this stores a bit vector of active nodes, indexed by 180 // node ID. 181 mutex mu_; 182 std::unique_ptr<std::vector<bool>> active_ TF_GUARDED_BY(mu_); 183 184 const std::vector<const NodeItem*>* const nodes_; 185 }; 186 187 } // namespace tensorflow 188 189 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_PROPAGATOR_STATE_H_ 190