/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_STATE_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_STATE_H_

#include <atomic>
#include <vector>

#include "tensorflow/core/common_runtime/entry.h"
#include "tensorflow/core/common_runtime/immutable_executor_state.h"
#include "tensorflow/core/common_runtime/pending_counts.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;

// Represents the ephemeral "edge state" associated with one invocation of
// `Executor::Run()`.
//
// `PropagatorState` is responsible for propagating values along dataflow
// edges in a TensorFlow graph and determining which nodes are runnable. The
// executor primarily updates `PropagatorState` by calling `PropagateOutputs()`
// after processing a node, and `PropagatorState` dispatches `TaggedNode`s by
// adding them to a `TaggedNodeSeq`.
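//
// The sketch below illustrates, in simplified form and with a hypothetical
// `RunKernel()` helper, how an executor might drive this class; the real
// executor adds scheduling, error handling, and tracing around these calls:
//
//   PropagatorState propagator(immutable_state, step_id, /*vlog=*/false);
//   PropagatorState::TaggedNodeSeq ready;
//   propagator.ActivateRoots(roots, &ready);
//   while (!ready.empty()) {
//     PropagatorState::TaggedNode tagged_node = ready.back();
//     ready.pop_back();
//     EntryVector outputs;
//     RunKernel(tagged_node, propagator.GetInputTensors(tagged_node),
//               &outputs);  // hypothetical helper, not part of this class
//     propagator.PropagateOutputs(tagged_node, &outputs, &ready);
//   }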
class PropagatorState {
 public:
  PropagatorState(const ImmutableExecutorState& immutable_state,
                  int64_t step_id, bool vlog);
  ~PropagatorState();

 private:
  // Forward declarations so that `TaggedNode` can include a `FrameState*` and
  // an `IterationState*`.
  struct FrameState;
  struct IterationState;

 public:
  // A `TaggedNode` corresponds to a single invocation of a node's kernel,
  // and it is created when the kernel becomes runnable (in a particular
  // iteration of a particular frame).
  struct TaggedNode {
    const NodeItem* node_item;
    FrameState* input_frame;
    IterationState* input_iter;
    bool is_dead;

    TaggedNode() = default;
    TaggedNode(const NodeItem* node_item, FrameState* in_frame,
               IterationState* in_iter, bool dead)
        : node_item(node_item),
          input_frame(in_frame),
          input_iter(in_iter),
          is_dead(dead) {}

    const NodeItem& get_node_item() const { return *node_item; }

    bool get_is_dead() const { return is_dead; }
    int64 get_iter_num() const;
  };

  // A drop-in replacement for std::deque<TaggedNode>. We typically don't
  // have that many nodes in the ready queue, so we just use a vector and
  // don't free up memory from the queue as we consume nodes.
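  //
  // Illustrative usage (a sketch only; the actual callers live in the
  // executor):
  //
  //   TaggedNodeReadyQueue inline_ready;
  //   inline_ready.push_back(tagged_node);
  //   while (!inline_ready.empty()) {
  //     TaggedNode node = inline_ready.front();
  //     inline_ready.pop_front();
  //     // ... process `node` ...
  //   }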
  class TaggedNodeReadyQueue {
   public:
    TaggedNodeReadyQueue() : front_index_(0) {}

    void push_back(const TaggedNode& node) { ready_.push_back(node); }
    TaggedNode front() const {
      DCHECK_LT(front_index_, ready_.size());
      return ready_[front_index_];
    }
    void pop_front() {
      DCHECK_LT(front_index_, ready_.size());
      front_index_++;
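      // Reclaim space once the queue has been fully consumed, or once more
      // than kSpillThreshold consumed entries have piled up at the front.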
      if ((front_index_ == ready_.size()) || (front_index_ > kSpillThreshold)) {
        if (front_index_ == ready_.size()) {
          ready_.clear();
        } else {
          // Lots of unused entries at beginning of vector: move everything
          // down to start of vector.
          ready_.erase(ready_.begin(), ready_.begin() + front_index_);
        }
        front_index_ = 0;
      }
    }
    bool empty() const { return ready_.empty(); }

   private:
    // TODO(b/152925936): Re-evaluate these constants with current usage
    // patterns.
    static constexpr int kSpillThreshold = 16384;
    gtl::InlinedVector<TaggedNode, 16> ready_;
    int front_index_;
  };

  // TODO(b/152925936): Re-evaluate this constant with current usage patterns.
  typedef gtl::InlinedVector<TaggedNode, 8> TaggedNodeSeq;

 private:
  // The state of an iteration in a particular frame.
  struct IterationState {
    explicit IterationState(int64_t iter_num,
                            const PendingCounts* pending_counts,
                            int total_input_tensors)
        : iter_num(iter_num),
          input_tensors(new Entry[total_input_tensors]),
          outstanding_ops(0),
          outstanding_frame_count(0),
          counts(*pending_counts) {  // Initialize with copy of *pending_counts
    }

    const int64 iter_num;  // The index of this iteration in the enclosing loop.

    // One copy per iteration. For iteration k, i-th node's j-th input is in
    // input_tensors[k][immutable_state_.nodes[i].input_start + j]. An entry is
    // either a tensor pointer (pass-by-reference) or a tensor (pass-by-value).
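    // For example, a node with input_start == 5 finds its j-th input for this
    // iteration at input_tensors[5 + j].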
    //
    // NOTE: No need to protect input_tensors[i] by any locks because it
    // is resized once. Each element of tensors_ is written once by the
    // source node of an edge and is cleared by the destination of the same
    // edge. The latter node is never run concurrently with the former node.
    Entry* input_tensors;

    // The number of outstanding ops for each iteration.
    std::atomic<size_t> outstanding_ops;

    // The number of outstanding frames for each iteration.
    int outstanding_frame_count;
    int pending(PendingCounts::Handle h) { return counts.pending(h); }
    int decrement_pending(PendingCounts::Handle h, int v) {
      return counts.decrement_pending(h, v);
    }
    // Mark a merge node as live
    // REQUIRES: Node corresponding to "h" is a merge node
    void mark_live(PendingCounts::Handle h) { counts.mark_live(h); }
    // Mark a node to show that processing has started.
    void mark_started(PendingCounts::Handle h) { counts.mark_started(h); }
    // Mark a node to show that processing has completed.
    void mark_completed(PendingCounts::Handle h) { counts.mark_completed(h); }
    PendingCounts::NodeState node_state(PendingCounts::Handle h) {
      return counts.node_state(h);
    }

    int dead_count(PendingCounts::Handle h) { return counts.dead_count(h); }
    void increment_dead_count(PendingCounts::Handle h) {
      counts.increment_dead_count(h);
    }
    PendingCounts::AdjustResult adjust_for_activation(PendingCounts::Handle h,
                                                      bool increment_dead) {
      return counts.adjust_for_activation(h, increment_dead);
    }
    PendingCounts::AdjustResult adjust_for_activation_atomic(
        PendingCounts::Handle h, bool increment_dead) {
      return counts.adjust_for_activation_atomic(h, increment_dead);
    }

    ~IterationState() { delete[] input_tensors; }

   private:
    PendingCounts counts;
  };

  struct FrameState {
    explicit FrameState(const ImmutableExecutorState& immutable_state,
                        int parallel_iters)
        : immutable_state(immutable_state),
          max_parallel_iterations(parallel_iters),
          num_outstanding_iterations(1),
          iterations(parallel_iters + 1),
          iterations_raw(iterations.data()) {}

    // A new frame is created for each loop. Execution starts at iteration 0.
    // When a value at iteration 0 passes through a NextIteration node,
    // iteration 1 is created and starts running. Note that iteration 0 may
    // still be running so multiple iterations may run in parallel. The
    // frame maintains the state of iterations in several data structures
    // such as pending_count and input_tensors. When iteration 0 completes,
    // we garbage collect the state of iteration 0.
    //
    // A frame instance is considered "done" and can be garbage collected
    // if all its inputs have entered and all its iterations are "done".
    //
    // A frame manages the live iterations of an iterative computation.
    // Iteration i is considered "done" when there are no outstanding ops,
    // frames at iteration i are done, all recvs for this iteration are
    // completed, and iteration i-1 is done. For iteration 0, we instead
    // wait for there to be no more pending inputs of the frame.
    //
    // Frames and iterations are garbage collected once they are done.
    // The state we need to keep around is highly dependent on the
    // parallelism enabled by the scheduler. We may want to have the
    // scheduler dynamically control the outstanding number of live
    // parallel frames and iterations. To reduce the state space, the
    // scheduler might want to schedule ops in inner frames first and
    // lower iterations first.
    //
    // This frame state is mostly initialized lazily on demand so we
    // don't introduce unnecessary overhead.

    // The immutable state of the executor the frame is in.
    const ImmutableExecutorState& immutable_state;

    // The name of this frame, which is the concatenation of its parent
    // frame name, the iteration of the parent frame when this frame was
    // created, and the value of the attr 'frame_name'.
    string frame_name;

    // The unique id for this frame. Generated by fingerprinting
    // frame_name.
    uint64 frame_id;

    // The iteration state of its parent frame when this frame is created.
    // nullptr if there is no parent frame. The frame_name/parent_iter pair
    // uniquely identifies this FrameState.
    IterationState* parent_iter = nullptr;

    // The FrameState of its parent frame.
    FrameState* parent_frame = nullptr;

    // The maximum allowed number of parallel iterations.
    const int max_parallel_iterations;

    // The number of inputs this frame is still waiting for.
    int num_pending_inputs = 0;

    // The highest iteration number we have reached so far in this frame.
    int64 iteration_count TF_GUARDED_BY(mu) = 0;

    // The number of outstanding iterations.
    int num_outstanding_iterations TF_GUARDED_BY(mu) = 1;

   private:
    // The active iteration states of this frame.
    gtl::InlinedVector<IterationState*, 12> iterations;
    IterationState** const iterations_raw TF_GUARDED_BY(mu);
    IterationState* iterations_first TF_GUARDED_BY(mu);

   public:
    // The NextIteration nodes to enter a new iteration. If the number of
    // outstanding iterations reaches the limit, we will defer the start of
    // the next iteration until the number of outstanding iterations falls
    // below the limit.
    std::vector<std::pair<const NodeItem*, Entry>> next_iter_roots
        TF_GUARDED_BY(mu);

    // The values of the loop invariants for this loop. They are added into
    // this list as they "enter" the frame. When a loop invariant enters,
    // we make it available to all active iterations. When the frame starts
    // a new iteration, we make all the current loop invariants available
    // to the new iteration.
    std::vector<std::pair<const NodeItem*, Entry>> inv_values TF_GUARDED_BY(mu);

    // The list of dead exit node items for the current highest iteration. We
    // will only "execute" the dead exits of the final iteration.
    std::vector<const NodeItem*> dead_exits TF_GUARDED_BY(mu);

    // Static information specific to this frame.
    PendingCounts* pending_counts = nullptr;
    int total_input_tensors = 0;
    std::vector<const NodeItem*>* nodes = nullptr;

    // Lock ordering: ExecutorState.mu_ < mu;
    // during structured traversal: parent_frame->mu < mu.
    mutex mu;

    void InitializeFrameInfo(const ImmutableExecutorState::FrameInfo& finfo);

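    // The iteration states of this frame form a ring buffer of size
    // `max_parallel_iterations + 1`: iteration `iter` is stored at index
    // `iter % (max_parallel_iterations + 1)`, and the state for iteration 0 is
    // also cached in `iterations_first` to fast-path the common non-loop case.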
    inline IterationState* GetIteration(int64_t iter)
        TF_SHARED_LOCKS_REQUIRED(mu) {
      if (TF_PREDICT_TRUE(iter == 0)) {
        return iterations_first;
      } else {
        size_t index = iter % (max_parallel_iterations + 1);
        return iterations_raw[index];
      }
    }

    void SetIteration(int64_t iter, IterationState* state);

    // Adjust the outstanding op count by 'delta' and clean up the iterations
    // in the frame if no more ops are outstanding. Return true iff the
    // execution of the frame is done.
    //
    // Avoids acquiring the lock in the common case that the frame is not done.
    bool AdjustOutstandingOps(IterationState* iter_state, int delta,
                              TaggedNodeSeq* ready);

    bool AdjustOutstandingOpsLocked(IterationState* iter_state, int delta,
                                    TaggedNodeSeq* ready)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu);

    bool AdjustOutstandingOpsFastPath(IterationState* iter_state, int delta)
        TF_SHARED_LOCKS_REQUIRED(mu);

    // Convenience methods for the above 'Adjust' calls where delta takes the
    // common value of -1.
    bool DecrementOutstandingOps(IterationState* iter_state,
                                 TaggedNodeSeq* ready);

    bool DecrementOutstandingOpsLocked(IterationState* iter_state,
                                       TaggedNodeSeq* ready);

    // Returns true if the computation in the frame is completed.
    bool IsFrameDone();

    // Returns true if the iteration of the frame is completed.
    bool IsIterationDone(IterationState* iter_state)
        TF_SHARED_LOCKS_REQUIRED(mu);

    // Increments the iteration id. If this is a new iteration, initialize it.
    //
    // Returns a pointer to the new iteration.
    IterationState* IncrementIteration(TaggedNodeSeq* ready)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu);

    // Activate all the deferred NextIteration nodes in a new iteration.
    void ActivateNexts(IterationState* iter_state, TaggedNodeSeq* ready)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu);

    // Activate all the current loop invariants in a new iteration.
    void ActivateLoopInvs(IterationState* iter_state, TaggedNodeSeq* ready)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu);

    // Add a new loop invariant and make it available to all active
    // iterations.
    void AddLoopInv(const NodeItem* item, const Entry& entry,
                    TaggedNodeSeq* ready) TF_EXCLUSIVE_LOCKS_REQUIRED(mu);

    // Activate the successors of a node. Contents of *outputs are left in an
    // indeterminate state after returning from this method.
    //
    // In the case that 'item' is a simple node (no merge/control outputs) this
    // will acquire a shared lock and can run concurrently with other
    // invocations.
    //
    // Return true if the frame is done after activation.
    bool ActivateNodesAndAdjustOutstanding(const NodeItem* item,
                                           const bool is_dead,
                                           IterationState* iter_state,
                                           EntryVector* outputs,
                                           TaggedNodeSeq* ready);

    // Same as the above, but requires 'mu' already held in exclusive mode.
    int ActivateNodesLocked(const NodeItem* item, const bool is_dead,
                            IterationState* iter_state, EntryVector* outputs,
                            TaggedNodeSeq* ready)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu);

    // Cleanup iterations of this frame starting from the given iteration.
    bool CleanupIterations(IterationState* iter_state, TaggedNodeSeq* ready)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu);

    void DumpIterationState(PropagatorState* parent) {
      mutex_lock l(mu);
      for (IterationState* iteration : iterations) {
        if (iteration) {
          LOG(WARNING) << " Iteration:";
          parent->DumpIterationState(this, iteration);
        }
      }
    }

    ~FrameState() {
      for (size_t i = 0; i < iterations.size(); ++i) {
        delete iterations[i];
        iterations[i] = nullptr;
      }
    }

   private:
    // REQUIRES: `!item->is_any_consumer_merge_or_control_trigger`.
    // This variant does not use atomic operations to modify the pending counts
    // and thus must hold the exclusive lock.
    int ActivateNodesFastPathLocked(const NodeItem* item, const bool is_dead,
                                    IterationState* iter_state,
                                    EntryVector* outputs, TaggedNodeSeq* ready)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
      return ActivateNodesFastPathInternal<false>(item, is_dead, iter_state,
                                                  outputs, ready);
    }

    // REQUIRES: `!item->is_any_consumer_merge_or_control_trigger`.
    // This variant uses atomic operations to modify the pending counts.
    int ActivateNodesFastPathShared(const NodeItem* item, const bool is_dead,
                                    IterationState* iter_state,
                                    EntryVector* outputs, TaggedNodeSeq* ready)
        TF_SHARED_LOCKS_REQUIRED(mu) {
      return ActivateNodesFastPathInternal<true>(item, is_dead, iter_state,
                                                 outputs, ready);
    }

    template <bool atomic>
    int ActivateNodesFastPathInternal(const NodeItem* item, const bool is_dead,
                                      IterationState* iter_state,
                                      EntryVector* outputs,
                                      TaggedNodeSeq* ready);

    int ActivateNodesSlowPath(const NodeItem* item, const bool is_dead,
                              IterationState* iter_state, EntryVector* outputs,
                              TaggedNodeSeq* ready)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
  };

 public:
  // Creates and adds a `TaggedNode` for each node in `roots` to `*ready`.
  void ActivateRoots(gtl::ArraySlice<const NodeItem*> roots,
                     TaggedNodeSeq* ready);

  // After a node has been processed, propagates its outputs to their
  // destinations. Contents of *outputs are left in an indeterminate state
  // after returning from this method.
  void PropagateOutputs(const TaggedNode& tagged_node, EntryVector* outputs,
                        TaggedNodeSeq* ready);

  // Returns an array of `Entry` objects corresponding to the inputs of
  // `tagged_node`.
  //
  // NOTE: Thread safety analysis is disabled on this method, because the
  // underlying `IterationState` and its array of `input_tensors` retain the
  // same address while the iteration is live.
  Entry* GetInputTensors(const TaggedNode& tagged_node) const
      TF_NO_THREAD_SAFETY_ANALYSIS {
    return tagged_node.input_iter->input_tensors +
           tagged_node.node_item->input_start;
  }

  FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const {
    return {tagged_node.input_frame->frame_id,
            tagged_node.input_iter->iter_num};
  }

  // Provide debugging output of the state of the executor.
  void DumpState();

  // For debugging/logging only.
  void MaybeMarkStarted(const TaggedNode& tagged_node) {
    // TODO(misard) Replace with a finer-grain enabling flag once we add better
    // optional debugging support.
    if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) {
      mutex_lock l(tagged_node.input_frame->mu);
      tagged_node.input_iter->mark_started(
          immutable_state_.pending_ids()[tagged_node.node_item->node_id]);
    }
  }

  void MaybeMarkCompleted(const TaggedNode& tagged_node) {
    // TODO(misard) Replace with a finer-grain enabling flag once we add better
    // optional debugging support.
    if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) {
      mutex_lock l(tagged_node.input_frame->mu);
      tagged_node.input_iter->mark_completed(
          immutable_state_.pending_ids()[tagged_node.node_item->node_id]);
    }
  }

 private:
  // Find an existing or create a new child frame in the frame 'frame' at
  // iteration 'iter'.
  void FindOrCreateChildFrame(FrameState* frame, IterationState* iter_state,
                              const NodeItem& node_item, FrameState** child);

  // Delete a frame. Called when the frame is done.
  void DeleteFrame(FrameState* frame, TaggedNodeSeq* ready);

  // Cleanup frames and iterations starting from frame/iter. Called when
  // a child frame is done.
  void CleanupFramesIterations(FrameState* frame, IterationState* iter_state,
                               TaggedNodeSeq* ready);

  // Provide debugging output about an outstanding iteration in the executor.
  void DumpIterationState(const FrameState* frame, IterationState* iteration);

  const ImmutableExecutorState& immutable_state_;
  const int64 step_id_;
  const bool vlog_;

  mutex mu_;

  // The root frame in which the execution of this step is started.
  FrameState* root_frame_;

  // Mapping from frame ID to outstanding frames. A new frame is created
  // at some iteration of an active frame. So the unique key for the new
  // child frame is a hash composed of the ID of the parent frame, the
  // iteration number at which the parent frame is creating the new frame,
  // and the name of the new frame from nodedef.
  absl::flat_hash_map<uint64, FrameState*> outstanding_frames_
      TF_GUARDED_BY(mu_);

  TF_DISALLOW_COPY_AND_ASSIGN(PropagatorState);
};

inline int64 PropagatorState::TaggedNode::get_iter_num() const {
  return input_iter->iter_num;
}

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_STATE_H_