/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_PROPAGATOR_STATE_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_PROPAGATOR_STATE_H_

#include <vector>

#include "tensorflow/core/common_runtime/entry.h"
#include "tensorflow/core/common_runtime/immutable_executor_state.h"
#include "tensorflow/core/common_runtime/pending_counts.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Represents the ephemeral "edge state" associated with one invocation of
// `Executor::Run()`.
//
// NOTE: `SimplePropagatorState` does not support "v1-style" control flow,
// including "dead tensors", "Switch" and "Merge" nodes, and cycles in the
// graph. Use `PropagatorState` for graphs with those features.
// `SimplePropagatorState` *does* support "v2-style" or "functional" control
// flow.
//
// `SimplePropagatorState` is responsible for propagating values along dataflow
// edges in a TensorFlow graph and determining which nodes are runnable. The
// executor primarily updates `SimplePropagatorState` by calling
// `PropagateOutputs()` after processing a node, and `SimplePropagatorState`
// dispatches `TaggedNode`s by adding them to a `TaggedNodeSeq`.
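//
// A minimal sketch of that driving loop (hypothetical; the real executor in
// `executor.cc` also runs kernels, handles errors, and schedules work across
// threads, and `roots` and `RunKernel` below are placeholders, not APIs
// defined in this file):
//
//   SimplePropagatorState propagator(immutable_state, step_id, vlog);
//   SimplePropagatorState::TaggedNodeSeq ready;
//   propagator.ActivateRoots(roots, &ready);
//   while (!ready.empty()) {
//     SimplePropagatorState::TaggedNode tagged_node = ready.back();
//     ready.pop_back();
//     Entry* inputs = propagator.GetInputTensors(tagged_node);
//     EntryVector outputs;
//     RunKernel(tagged_node.get_node_item(), inputs, &outputs);
//     propagator.PropagateOutputs(tagged_node, &outputs, &ready);
//   }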
class SimplePropagatorState {
 public:
  SimplePropagatorState(const ImmutableExecutorState& immutable_state,
                        int64 step_id, bool vlog);
  ~SimplePropagatorState();

  // A `TaggedNode` corresponds to a single invocation of a node's kernel,
  // and it is created when the kernel becomes runnable.
  struct TaggedNode {
    const NodeItem* node_item;

    explicit TaggedNode(const NodeItem* node_item) : node_item(node_item) {}

    const NodeItem& get_node_item() const { return *node_item; }

    bool get_is_dead() const { return false; }
    int64 get_iter_num() const { return 0; }
  };

  // A drop-in replacement for std::deque<TaggedNode>.  We typically don't
  // have that many nodes in the ready queue, so we just use a vector and
  // don't free up memory from the queue as we consume nodes.
  // TODO(mrry): Extract this and share it with the version in
  // `PropagatorState`. The correct constants might be different, since
  // sizeof(TaggedNode) is smaller in this version.
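  //
  // Example (a hypothetical caller; `ready` stands for a TaggedNodeSeq
  // produced by the propagator, and the drain pattern mirrors how a
  // std::deque<TaggedNode> would be used):
  //
  //   TaggedNodeReadyQueue inline_ready;
  //   for (const TaggedNode& t : ready) inline_ready.push_back(t);
  //   while (!inline_ready.empty()) {
  //     TaggedNode tagged_node = inline_ready.front();
  //     inline_ready.pop_front();
  //     // ... process `tagged_node` ...
  //   }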
  class TaggedNodeReadyQueue {
   public:
    TaggedNodeReadyQueue() : front_index_(0) {}

    void push_back(const TaggedNode& node) { ready_.push_back(node); }
    TaggedNode front() const {
      DCHECK_LT(front_index_, ready_.size());
      return ready_[front_index_];
    }
    void pop_front() {
      DCHECK_LT(front_index_, ready_.size());
      front_index_++;
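      // Most pops just advance `front_index_`. The backing vector is only
      // reset (when fully drained) or compacted (when the consumed prefix
      // exceeds kSpillThreshold entries), matching the "don't free up memory
      // as we consume nodes" note above.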
      if ((front_index_ == ready_.size()) || (front_index_ > kSpillThreshold)) {
        if (front_index_ == ready_.size()) {
          ready_.clear();
        } else {
          // Lots of unused entries at beginning of vector: move everything
          // down to start of vector.
          ready_.erase(ready_.begin(), ready_.begin() + front_index_);
        }
        front_index_ = 0;
      }
    }
    bool empty() const { return ready_.empty(); }

   private:
    // TODO(b/152925936): Re-evaluate these constants with current usage
    // patterns.
    static constexpr int kSpillThreshold = 16384;
    gtl::InlinedVector<TaggedNode, 16> ready_;
    int front_index_;
  };

  // TODO(b/152925936): Re-evaluate this constant with current usage patterns.
  typedef gtl::InlinedVector<TaggedNode, 8> TaggedNodeSeq;

  // Creates and adds a `TaggedNode` for each node in `roots` to `*ready`.
  void ActivateRoots(gtl::ArraySlice<const NodeItem*> roots,
                     TaggedNodeSeq* ready);

  // After processing the outputs, propagates the outputs to their dsts.
  // Contents of *outputs are left in an indeterminate state after
  // returning from this method.
  void PropagateOutputs(const TaggedNode& tagged_node, EntryVector* outputs,
                        TaggedNodeSeq* ready);

  // Returns an array of `Entry` objects corresponding to the inputs of
  // `tagged_node`.
  Entry* GetInputTensors(const TaggedNode& tagged_node) {
#if defined(THREAD_SANITIZER) || defined(DEBUG)
    // NOTE: This read of `pending_[...]` works around a limitation in TSAN.
    // To avoid false positive data race reports, we need to perform an atomic
    // object access that will establish the happens-before relation between
    // the write to input_tensors_ in `PropagateOutputs()` and the read in
    // `PrepareInputs()`.
    CHECK_EQ(pending_[tagged_node.node_item->node_id], 0);
#endif  // defined(THREAD_SANITIZER) || defined(DEBUG)
    return input_tensors_.data() + tagged_node.node_item->input_start;
  }

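  // Without "v1-style" control flow there is only the single, outermost frame
  // at iteration 0, so this always returns {0, 0}.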
  FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const {
    return {0, 0};
  }

  // Provide debugging output of the state of the executor.
  void DumpState();

  // For debugging/logging only.
  void MaybeMarkStarted(const TaggedNode& tagged_node) {
    // TODO(misard) Replace with a finer-grain enabling flag once we add better
    // optional debugging support.
    if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) {
      mutex_lock l(mu_);
      (*active_)[tagged_node.node_item->node_id] = true;
    }
  }
  void MaybeMarkCompleted(const TaggedNode& tagged_node) {
    // TODO(misard) Replace with a finer-grain enabling flag once we add better
    // optional debugging support.
    if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) {
      mutex_lock l(mu_);
      (*active_)[tagged_node.node_item->node_id] = false;
    }
  }

 private:
  SimplePropagatorState(const ImmutableExecutorState& immutable_state_,
                        int64 step_id,
                        const ImmutableExecutorState::FrameInfo& finfo,
                        bool vlog);

  const ImmutableExecutorState& immutable_state_;
  const int64 step_id_;
  const bool vlog_;

  // The i-th node's j-th input is stored at
  // `input_tensors[impl_->nodes[i].input_start + j]`.
  //
  // NOTE: No need to protect input_tensors[i] by any locks because it
  // is resized once. Each element of input_tensors is written once by the
  // source node of an edge and is cleared by the destination of the same
  // edge. The destination node always runs after the source node, so there
  // is never concurrent access to the same entry.
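  //
  // For example, a (hypothetical) node whose `NodeItem::input_start` is 7 and
  // that has two inputs reads them from `input_tensors_[7]` and
  // `input_tensors_[8]`.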
  std::vector<Entry> input_tensors_;

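  // Per-node count of inputs that have not yet been propagated, indexed by
  // node ID. A node becomes runnable once its count reaches zero; the counts
  // are also read in `GetInputTensors()` above to establish a happens-before
  // relation under TSAN.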
  std::unique_ptr<std::atomic<int32>[]> pending_;

  // If `vlog_` is true, this stores a bit vector of active nodes, indexed by
  // node ID.
  mutex mu_;
  std::unique_ptr<std::vector<bool>> active_ TF_GUARDED_BY(mu_);

  const std::vector<const NodeItem*>* const nodes_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_PROPAGATOR_STATE_H_