/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_STATE_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_STATE_H_

#include <atomic>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/common_runtime/entry.h"
#include "tensorflow/core/common_runtime/immutable_executor_state.h"
#include "tensorflow/core/common_runtime/pending_counts.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;

// Represents the ephemeral "edge state" associated with one invocation of
// `Executor::Run()`.
//
// `PropagatorState` is responsible for propagating values along dataflow
// edges in a TensorFlow graph and determining which nodes are runnable. The
// executor primarily updates `PropagatorState` by calling `PropagateOutputs()`
// after processing a node, and `PropagatorState` dispatches `TaggedNode`s by
// adding them to a `TaggedNodeSeq`.
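//
// A rough, illustrative sketch of how an executor might drive this class.
// This is not the actual `ExecutorState` implementation: `RunKernel` is a
// hypothetical placeholder for kernel execution, and `root_nodes()` is
// assumed to expose the graph's root `NodeItem`s.
//
//   PropagatorState propagator(immutable_state, step_id, /*vlog=*/false);
//   PropagatorState::TaggedNodeSeq ready;
//   propagator.ActivateRoots(immutable_state.root_nodes(), &ready);
//   PropagatorState::TaggedNodeReadyQueue queue;
//   for (const auto& t : ready) queue.push_back(t);
//   while (!queue.empty()) {
//     PropagatorState::TaggedNode tagged_node = queue.front();
//     queue.pop_front();
//     Entry* first_input = propagator.GetInputTensors(tagged_node);
//     EntryVector outputs;
//     RunKernel(tagged_node, first_input, &outputs);  // hypothetical
//     ready.clear();
//     propagator.PropagateOutputs(tagged_node, &outputs, &ready);
//     for (const auto& t : ready) queue.push_back(t);
//   }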
class PropagatorState {
 public:
  PropagatorState(const ImmutableExecutorState& immutable_state,
                  int64_t step_id, bool vlog);
  ~PropagatorState();

 private:
  // Forward declaration so that `TaggedNode` can include a `FrameState*` and an
  // `IterationState*`.
  struct FrameState;
  struct IterationState;

 public:
  // A `TaggedNode` corresponds to a single invocation of a node's kernel,
  // and it is created when the kernel becomes runnable (in a particular
  // iteration of a particular frame).
  struct TaggedNode {
    const NodeItem* node_item;
    FrameState* input_frame;
    IterationState* input_iter;
    bool is_dead;

    TaggedNode() = default;
    TaggedNode(const NodeItem* node_item, FrameState* in_frame,
               IterationState* in_iter, bool dead)
        : node_item(node_item),
          input_frame(in_frame),
          input_iter(in_iter),
          is_dead(dead) {}

    const NodeItem& get_node_item() const { return *node_item; }

    bool get_is_dead() const { return is_dead; }
    int64 get_iter_num() const;
  };

  // A drop-in replacement for std::deque<TaggedNode>.  We typically don't
  // have that many nodes in the ready queue, so we just use a vector and
  // don't free up memory from the queue as we consume nodes.
  class TaggedNodeReadyQueue {
   public:
    TaggedNodeReadyQueue() : front_index_(0) {}

    void push_back(const TaggedNode& node) { ready_.push_back(node); }
    TaggedNode front() const {
      DCHECK_LT(front_index_, ready_.size());
      return ready_[front_index_];
    }
    void pop_front() {
      DCHECK_LT(front_index_, ready_.size());
      front_index_++;
      if ((front_index_ == ready_.size()) || (front_index_ > kSpillThreshold)) {
        if (front_index_ == ready_.size()) {
          ready_.clear();
        } else {
          // Lots of unused entries at beginning of vector: move everything
          // down to start of vector.
          ready_.erase(ready_.begin(), ready_.begin() + front_index_);
        }
        front_index_ = 0;
      }
    }
    bool empty() const { return ready_.empty(); }

   private:
    // TODO(b/152925936): Re-evaluate these constants with current usage
    // patterns.
    static constexpr int kSpillThreshold = 16384;
    gtl::InlinedVector<TaggedNode, 16> ready_;
    int front_index_;
  };
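
  // Illustrative FIFO usage of `TaggedNodeReadyQueue` (a sketch only; `item`,
  // `frame` and `iter` stand for pointers obtained from the executor state):
  //
  //   TaggedNodeReadyQueue queue;
  //   queue.push_back(TaggedNode(item, frame, iter, /*dead=*/false));
  //   while (!queue.empty()) {
  //     TaggedNode node = queue.front();
  //     queue.pop_front();  // Advances an index; memory is reused, not freed.
  //     // ... process `node` ...
  //   }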

  // TODO(b/152925936): Re-evaluate this constant with current usage patterns.
  typedef gtl::InlinedVector<TaggedNode, 8> TaggedNodeSeq;

 private:
  // The state of an iteration in a particular frame.
  struct IterationState {
    explicit IterationState(int64_t iter_num,
                            const PendingCounts* pending_counts,
                            int total_input_tensors)
        : iter_num(iter_num),
          input_tensors(new Entry[total_input_tensors]),
          outstanding_ops(0),
          outstanding_frame_count(0),
          counts(*pending_counts) {  // Initialize with copy of *pending_counts
    }

    const int64 iter_num;  // The index of this iteration in the enclosing loop.

    // One copy per iteration. For iteration k, i-th node's j-th input is in
    // input_tensors[k][immutable_state_.nodes[i].input_start + j]. An entry is
    // either a tensor pointer (pass-by-reference) or a tensor (pass-by-value).
    //
    // NOTE: No need to protect input_tensors[i] by any locks because it
    // is resized once. Each element of tensors_ is written once by the
    // source node of an edge and is cleared by the destination of the same
    // edge. The latter node is never run concurrently with the former node.
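    //
    // Worked example (illustrative numbers, not from a real graph): if node i
    // has `immutable_state_.nodes[i].input_start == 12`, then its inputs 0 and
    // 1 for this iteration live at `input_tensors[12]` and `input_tensors[13]`
    // of this iteration's array.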
    Entry* input_tensors;

    // The number of outstanding ops for each iteration.
    std::atomic<size_t> outstanding_ops;

    // The number of outstanding frames for each iteration.
    int outstanding_frame_count;
    int pending(PendingCounts::Handle h) { return counts.pending(h); }
    int decrement_pending(PendingCounts::Handle h, int v) {
      return counts.decrement_pending(h, v);
    }
    // Mark a merge node as live
    // REQUIRES: Node corresponding to "h" is a merge node
    void mark_live(PendingCounts::Handle h) { counts.mark_live(h); }
    // Mark a node to show that processing has started.
    void mark_started(PendingCounts::Handle h) { counts.mark_started(h); }
    // Mark a node to show that processing has completed.
    void mark_completed(PendingCounts::Handle h) { counts.mark_completed(h); }
    PendingCounts::NodeState node_state(PendingCounts::Handle h) {
      return counts.node_state(h);
    }

    int dead_count(PendingCounts::Handle h) { return counts.dead_count(h); }
    void increment_dead_count(PendingCounts::Handle h) {
      counts.increment_dead_count(h);
    }
    PendingCounts::AdjustResult adjust_for_activation(PendingCounts::Handle h,
                                                      bool increment_dead) {
      return counts.adjust_for_activation(h, increment_dead);
    }
    PendingCounts::AdjustResult adjust_for_activation_atomic(
        PendingCounts::Handle h, bool increment_dead) {
      return counts.adjust_for_activation_atomic(h, increment_dead);
    }

    ~IterationState() { delete[] input_tensors; }

   private:
    PendingCounts counts;
  };

  struct FrameState {
    explicit FrameState(const ImmutableExecutorState& immutable_state,
                        int parallel_iters)
        : immutable_state(immutable_state),
          max_parallel_iterations(parallel_iters),
          num_outstanding_iterations(1),
          iterations(parallel_iters + 1),
          iterations_raw(iterations.data()) {}

    // A new frame is created for each loop. Execution starts at iteration 0.
    // When a value at iteration 0 passes through a NextIteration node,
    // iteration 1 is created and starts running. Note that iteration 0 may
    // still be running so multiple iterations may run in parallel. The
    // frame maintains the state of iterations in several data structures
    // such as pending_count and input_tensors. When iteration 0 completes,
    // we garbage collect the state of iteration 0.
    //
    // A frame instance is considered "done" and can be garbage collected
    // if all its inputs have entered and all its iterations are "done".
    //
    // A frame manages the live iterations of an iterative computation.
    // Iteration i is considered "done" when there are no outstanding ops,
    // frames at iteration i are done, all recvs for this iteration are
    // completed, and iteration i-1 is done. For iteration 0, we instead
    // wait for there to be no more pending inputs of the frame.
    //
    // Frames and iterations are garbage collected once they are done.
    // The state we need to keep around is highly dependent on the
    // parallelism enabled by the scheduler. We may want to have the
    // scheduler dynamically control the outstanding number of live
    // parallel frames and iterations. To reduce the state space, the
    // scheduler might want to schedule ops in inner frames first and
    // lower iterations first.
    //
    // This frame state is mostly initialized lazily on demand so we
    // don't introduce unnecessary overhead.
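    //
    // Illustrative example (behavior inferred from the comments above and the
    // members below, not from a specific graph): with
    // max_parallel_iterations = 2, a value flowing through a NextIteration
    // node while iterations 0 and 1 are both still outstanding does not start
    // iteration 2 immediately; the node/entry pair is parked in
    // `next_iter_roots` and activated once an older iteration finishes.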

    // The immutable state of the executor the frame is in.
    const ImmutableExecutorState& immutable_state;

    // The name of this frame, which is the concatenation of its parent
    // frame name, the iteration of the parent frame when this frame was
    // created, and the value of the attr 'frame_name'.
    string frame_name;

    // The unique id for this frame. Generated by fingerprinting
    // frame_name.
    uint64 frame_id;

    // The iteration state of its parent frame when this frame is created.
    // nullptr if there is no parent frame. The frame_name/parent_iter pair
    // uniquely identifies this FrameState.
    IterationState* parent_iter = nullptr;

    // The FrameState of its parent frame.
    FrameState* parent_frame = nullptr;

    // The maximum allowed number of parallel iterations.
    const int max_parallel_iterations;

    // The number of inputs this frame is still waiting for.
    int num_pending_inputs = 0;

    // The highest iteration number we have reached so far in this frame.
    int64 iteration_count TF_GUARDED_BY(mu) = 0;

    // The number of outstanding iterations.
    int num_outstanding_iterations TF_GUARDED_BY(mu) = 1;

   private:
    // The active iteration states of this frame.
    gtl::InlinedVector<IterationState*, 12> iterations;
    IterationState** const iterations_raw TF_GUARDED_BY(mu);
    IterationState* iterations_first TF_GUARDED_BY(mu);

   public:
    // The NextIteration nodes to enter a new iteration. If the number of
    // outstanding iterations reaches the limit, we will defer the start of
    // the next iteration until the number of outstanding iterations falls
    // below the limit.
    std::vector<std::pair<const NodeItem*, Entry>> next_iter_roots
        TF_GUARDED_BY(mu);

    // The values of the loop invariants for this loop. They are added into
    // this list as they "enter" the frame. When a loop invariant enters,
    // we make it available to all active iterations. When the frame starts
    // a new iteration, we make all the current loop invariants available
    // to the new iteration.
    std::vector<std::pair<const NodeItem*, Entry>> inv_values TF_GUARDED_BY(mu);

    // The list of dead exit node items for the current highest iteration. We
    // will only "execute" the dead exits of the final iteration.
    std::vector<const NodeItem*> dead_exits TF_GUARDED_BY(mu);

    // Static information specific to this frame.
    PendingCounts* pending_counts = nullptr;
    int total_input_tensors = 0;
    std::vector<const NodeItem*>* nodes = nullptr;

    // Lock ordering: ExecutorState.mu_ < mu;
    // during structured traversal: parent_frame->mu < mu.
    mutex mu;

    void InitializeFrameInfo(const ImmutableExecutorState::FrameInfo& finfo);

    inline IterationState* GetIteration(int64_t iter)
        TF_SHARED_LOCKS_REQUIRED(mu) {
      if (TF_PREDICT_TRUE(iter == 0)) {
        return iterations_first;
      } else {
        size_t index = iter % (max_parallel_iterations + 1);
        return iterations_raw[index];
      }
    }
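
    // Worked example for the indexing above (illustrative numbers): the
    // constructor sizes `iterations` to max_parallel_iterations + 1 slots, so
    // with max_parallel_iterations = 2 there are 3 slots and iteration 5 maps
    // to slot 5 % 3 == 2. Iteration 0 is special-cased via `iterations_first`
    // to skip the modulo on the common path.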

    void SetIteration(int64_t iter, IterationState* state);

    // Adjust the outstanding op count by 'delta' and clean up the iterations in
    // the frame if no more ops are outstanding. Return true iff the execution of
    // the frame is done.
    //
    // Avoids acquiring the lock in the common case that the frame is not done.
    bool AdjustOutstandingOps(IterationState* iter_state, int delta,
                              TaggedNodeSeq* ready);

    bool AdjustOutstandingOpsLocked(IterationState* iter_state, int delta,
                                    TaggedNodeSeq* ready)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu);

    bool AdjustOutstandingOpsFastPath(IterationState* iter_state, int delta)
        TF_SHARED_LOCKS_REQUIRED(mu);

    // Convenience methods for the above 'Adjust' calls where delta takes the
    // common value of -1.
    bool DecrementOutstandingOps(IterationState* iter_state,
                                 TaggedNodeSeq* ready);

    bool DecrementOutstandingOpsLocked(IterationState* iter_state,
                                       TaggedNodeSeq* ready);

    // Returns true if the computation in the frame is completed.
    bool IsFrameDone();

    // Returns true if the iteration of the frame is completed.
    bool IsIterationDone(IterationState* iter_state)
        TF_SHARED_LOCKS_REQUIRED(mu);

    // Increments the iteration id. If this is a new iteration, initialize it.
    //
    // Returns a pointer to the new iteration.
    IterationState* IncrementIteration(TaggedNodeSeq* ready)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu);

    // Activate all the deferred NextIteration nodes in a new iteration.
    void ActivateNexts(IterationState* iter_state, TaggedNodeSeq* ready)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu);

    // Activate all the current loop invariants in a new iteration.
    void ActivateLoopInvs(IterationState* iter_state, TaggedNodeSeq* ready)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu);

    // Add a new loop invariant and make it available to all active
    // iterations.
    void AddLoopInv(const NodeItem* item, const Entry& entry,
                    TaggedNodeSeq* ready) TF_EXCLUSIVE_LOCKS_REQUIRED(mu);

    // Activate the successors of a node. Contents of *outputs are left in an
    // indeterminate state after returning from this method.
    //
    // In the case that 'item' is a simple node (no merge/control outputs) this
    // will acquire a shared lock and can run concurrently with other
    // invocations.
    //
    // Return true if the frame is done after activation.
    bool ActivateNodesAndAdjustOutstanding(const NodeItem* item,
                                           const bool is_dead,
                                           IterationState* iter_state,
                                           EntryVector* outputs,
                                           TaggedNodeSeq* ready);

    // Same as the above, but requires 'mu' already held in exclusive mode.
    int ActivateNodesLocked(const NodeItem* item, const bool is_dead,
                            IterationState* iter_state, EntryVector* outputs,
                            TaggedNodeSeq* ready)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu);

    // Cleanup iterations of this frame starting from the given iteration.
    bool CleanupIterations(IterationState* iter_state, TaggedNodeSeq* ready)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu);

    void DumpIterationState(PropagatorState* parent) {
      mutex_lock l(mu);
      for (IterationState* iteration : iterations) {
        if (iteration) {
          LOG(WARNING) << "  Iteration:";
          parent->DumpIterationState(this, iteration);
        }
      }
    }

    ~FrameState() {
      for (size_t i = 0; i < iterations.size(); ++i) {
        delete iterations[i];
        iterations[i] = nullptr;
      }
    }

   private:
    // REQUIRES: `!item->is_any_consumer_merge_or_control_trigger`.
    // This variant does not use atomic operations to modify the pending counts
    // and thus must hold the exclusive lock.
    int ActivateNodesFastPathLocked(const NodeItem* item, const bool is_dead,
                                    IterationState* iter_state,
                                    EntryVector* outputs, TaggedNodeSeq* ready)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
      return ActivateNodesFastPathInternal<false>(item, is_dead, iter_state,
                                                  outputs, ready);
    }

    // REQUIRES: `!item->is_any_consumer_merge_or_control_trigger`.
    // This variant uses atomic operations to modify the pending counts.
    int ActivateNodesFastPathShared(const NodeItem* item, const bool is_dead,
                                    IterationState* iter_state,
                                    EntryVector* outputs, TaggedNodeSeq* ready)
        TF_SHARED_LOCKS_REQUIRED(mu) {
      return ActivateNodesFastPathInternal<true>(item, is_dead, iter_state,
                                                 outputs, ready);
    }

    template <bool atomic>
    int ActivateNodesFastPathInternal(const NodeItem* item, const bool is_dead,
                                      IterationState* iter_state,
                                      EntryVector* outputs,
                                      TaggedNodeSeq* ready);

    int ActivateNodesSlowPath(const NodeItem* item, const bool is_dead,
                              IterationState* iter_state, EntryVector* outputs,
                              TaggedNodeSeq* ready)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
  };

 public:
  // Creates and adds a `TaggedNode` for each node in `roots` to `*ready`.
  void ActivateRoots(gtl::ArraySlice<const NodeItem*> roots,
                     TaggedNodeSeq* ready);

  // After processing the outputs, propagates the outputs to their dsts.
  // Contents of *outputs are left in an indeterminate state after
  // returning from this method.
  void PropagateOutputs(const TaggedNode& tagged_node, EntryVector* outputs,
                        TaggedNodeSeq* ready);

  // Returns an array of `Entry` objects corresponding to the inputs of
  // `tagged_node`.
  //
  // NOTE: Thread safety analysis is disabled on this method, because the
  // underlying `IterationState` and its array of `input_tensors` retain the
  // same address while the iteration is live.
  Entry* GetInputTensors(const TaggedNode& tagged_node) const
      TF_NO_THREAD_SAFETY_ANALYSIS {
    return tagged_node.input_iter->input_tensors +
           tagged_node.node_item->input_start;
  }
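
  // Illustrative per-node use of the accessors in this class (a sketch, not
  // the actual executor code; `ExecuteKernel` is a hypothetical placeholder
  // for the executor's kernel invocation, and `GetFrameAndIter`,
  // `MaybeMarkStarted` and `MaybeMarkCompleted` are declared below):
  //
  //   Entry* first_input = GetInputTensors(tagged_node);
  //   FrameAndIter frame_and_iter = GetFrameAndIter(tagged_node);
  //   MaybeMarkStarted(tagged_node);
  //   EntryVector outputs;
  //   ExecuteKernel(tagged_node, first_input, frame_and_iter, &outputs);
  //   MaybeMarkCompleted(tagged_node);
  //   TaggedNodeSeq ready;
  //   PropagateOutputs(tagged_node, &outputs, &ready);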

  FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const {
    return {tagged_node.input_frame->frame_id,
            tagged_node.input_iter->iter_num};
  }

  // Provide debugging output of the state of the executor.
  void DumpState();

  // For debugging/logging only.
  void MaybeMarkStarted(const TaggedNode& tagged_node) {
    // TODO(misard) Replace with a finer-grain enabling flag once we add better
    // optional debugging support.
    if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) {
      mutex_lock l(tagged_node.input_frame->mu);
      tagged_node.input_iter->mark_started(
          immutable_state_.pending_ids()[tagged_node.node_item->node_id]);
    }
  }

  void MaybeMarkCompleted(const TaggedNode& tagged_node) {
    // TODO(misard) Replace with a finer-grain enabling flag once we add better
    // optional debugging support.
    if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) {
      mutex_lock l(tagged_node.input_frame->mu);
      tagged_node.input_iter->mark_completed(
          immutable_state_.pending_ids()[tagged_node.node_item->node_id]);
    }
  }

 private:
  // Find an existing or create a new child frame in the frame 'frame' at
  // iteration 'iter'.
  void FindOrCreateChildFrame(FrameState* frame, IterationState* iter_state,
                              const NodeItem& node_item, FrameState** child);

  // Delete a frame. Called when the frame is done.
  void DeleteFrame(FrameState* frame, TaggedNodeSeq* ready);

  // Cleanup frames and iterations starting from frame/iter. Called when
  // a child frame is done.
  void CleanupFramesIterations(FrameState* frame, IterationState* iter_state,
                               TaggedNodeSeq* ready);

  // Provide debugging output about an outstanding iteration in the executor.
  void DumpIterationState(const FrameState* frame, IterationState* iteration);

  const ImmutableExecutorState& immutable_state_;
  const int64 step_id_;
  const bool vlog_;

  mutex mu_;

  // The root frame in which the execution of this step is started.
  FrameState* root_frame_;

  // Mapping from frame ID to outstanding frames. A new frame is created
  // at some iteration of an active frame. So the unique key for the new
  // child frame is a hash composed of the ID of the parent frame, the iteration
  // number at which the parent frame is creating the new frame, and the
  // name of the new frame from nodedef.
  absl::flat_hash_map<uint64, FrameState*> outstanding_frames_
      TF_GUARDED_BY(mu_);

  TF_DISALLOW_COPY_AND_ASSIGN(PropagatorState);
};

inline int64 PropagatorState::TaggedNode::get_iter_num() const {
  return input_iter->iter_num;
}

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_STATE_H_