/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_STATE_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_STATE_H_

#include <atomic>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/common_runtime/entry.h"
#include "tensorflow/core/common_runtime/immutable_executor_state.h"
#include "tensorflow/core/common_runtime/pending_counts.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;

// Represents the ephemeral "edge state" associated with one invocation of
// `Executor::Run()`.
//
// `PropagatorState` is responsible for propagating values along dataflow
// edges in a TensorFlow graph and determining which nodes are runnable. The
// executor primarily updates `PropagatorState` by calling `PropagateOutputs()`
// after processing a node, and `PropagatorState` dispatches `TaggedNode`s by
// adding them to a `TaggedNodeSeq`.
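//
// A rough usage sketch (illustrative only; the real driver loop lives in the
// executor, and `ProcessNode` is a hypothetical stand-in for running a kernel
// and collecting its outputs):
//
//   PropagatorState propagator(immutable_state, step_id, /*vlog=*/false);
//   PropagatorState::TaggedNodeSeq ready;
//   propagator.ActivateRoots(roots, &ready);
//   while (!ready.empty()) {
//     PropagatorState::TaggedNode tagged_node = ready.back();
//     ready.pop_back();
//     EntryVector outputs;
//     ProcessNode(tagged_node, propagator.GetInputTensors(tagged_node),
//                 &outputs);
//     propagator.PropagateOutputs(tagged_node, &outputs, &ready);
//   }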
class PropagatorState {
 public:
  PropagatorState(const ImmutableExecutorState& immutable_state, int64 step_id,
                  bool vlog);
  ~PropagatorState();

 private:
  // Forward declaration so that `TaggedNode` can include a `FrameState*` and an
  // `IterationState*`.
  struct FrameState;
  struct IterationState;

 public:
  // A `TaggedNode` corresponds to a single invocation of a node's kernel,
  // and it is created when the kernel becomes runnable (in a particular
  // iteration of a particular frame).
  struct TaggedNode {
    const NodeItem* node_item;
    FrameState* input_frame;
    IterationState* input_iter;
    bool is_dead;

    TaggedNode() = default;
    TaggedNode(const NodeItem* node_item, FrameState* in_frame,
               IterationState* in_iter, bool dead)
        : node_item(node_item),
          input_frame(in_frame),
          input_iter(in_iter),
          is_dead(dead) {}

    const NodeItem& get_node_item() const { return *node_item; }

    bool get_is_dead() const { return is_dead; }
    int64 get_iter_num() const;
  };
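
  // Example (illustrative): during output propagation, a frame typically tags
  // a successor that has become runnable with something like
  //
  //   ready->emplace_back(&dst_item, input_frame, iter_state, dst_dead);
  //
  // where `dst_item` and `dst_dead` are hypothetical locals; the executor later
  // looks up the node's inputs via `GetInputTensors()`.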

  // A drop-in replacement for std::deque<TaggedNode>.  We typically don't
  // have that many nodes in the ready queue, so we just use a vector and
  // don't free up memory from the queue as we consume nodes.
  class TaggedNodeReadyQueue {
   public:
    TaggedNodeReadyQueue() : front_index_(0) {}

    void push_back(const TaggedNode& node) { ready_.push_back(node); }
    TaggedNode front() const {
      DCHECK_LT(front_index_, ready_.size());
      return ready_[front_index_];
    }
    void pop_front() {
      DCHECK_LT(front_index_, ready_.size());
      front_index_++;
      if ((front_index_ == ready_.size()) || (front_index_ > kSpillThreshold)) {
        if (front_index_ == ready_.size()) {
          ready_.clear();
        } else {
          // Lots of unused entries at beginning of vector: move everything
          // down to start of vector.
          ready_.erase(ready_.begin(), ready_.begin() + front_index_);
        }
        front_index_ = 0;
      }
    }
    bool empty() const { return ready_.empty(); }

   private:
    // TODO(b/152925936): Re-evaluate these constants with current usage
    // patterns.
    static constexpr int kSpillThreshold = 16384;
    gtl::InlinedVector<TaggedNode, 16> ready_;
    int front_index_;
  };
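
  // A minimal drain-loop sketch for the queue (illustrative; `Process` stands
  // in for whatever per-node work the caller performs):
  //
  //   TaggedNodeReadyQueue inline_ready;
  //   inline_ready.push_back(tagged_node);
  //   while (!inline_ready.empty()) {
  //     TaggedNode tn = inline_ready.front();
  //     inline_ready.pop_front();
  //     Process(tn);
  //   }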

  // TODO(b/152925936): Re-evaluate this constant with current usage patterns.
  typedef gtl::InlinedVector<TaggedNode, 8> TaggedNodeSeq;

 private:
  // The state of an iteration in a particular frame.
  struct IterationState {
    explicit IterationState(int64 iter_num, const PendingCounts* pending_counts,
                            int total_input_tensors)
        : iter_num(iter_num),
          input_tensors(new Entry[total_input_tensors]),
          outstanding_ops(0),
          outstanding_frame_count(0),
          counts(*pending_counts) {  // Initialize with copy of *pending_counts
    }

    const int64 iter_num;  // The index of this iteration in the enclosing loop.

    // One copy per iteration. For iteration k, i-th node's j-th input is in
    // input_tensors[k][immutable_state_.nodes[i].input_start + j]. An entry is
    // either a tensor pointer (pass-by-reference) or a tensor (pass-by-value).
    //
    // NOTE: No need to protect input_tensors[i] by any locks because it
    // is resized once. Each element of input_tensors is written once by the
    // source node of an edge and is cleared by the destination of the same
    // edge. The latter node is never run concurrently with the former node.
    Entry* input_tensors;

    // The number of outstanding ops for each iteration.
    std::atomic<size_t> outstanding_ops;

    // The number of outstanding frames for each iteration.
    int outstanding_frame_count;
    int pending(PendingCounts::Handle h) { return counts.pending(h); }
    int decrement_pending(PendingCounts::Handle h, int v) {
      return counts.decrement_pending(h, v);
    }
    // Mark a merge node as live
    // REQUIRES: Node corresponding to "h" is a merge node
    void mark_live(PendingCounts::Handle h) { counts.mark_live(h); }
    // Mark a node to show that processing has started.
    void mark_started(PendingCounts::Handle h) { counts.mark_started(h); }
    // Mark a node to show that processing has completed.
    void mark_completed(PendingCounts::Handle h) { counts.mark_completed(h); }
    PendingCounts::NodeState node_state(PendingCounts::Handle h) {
      return counts.node_state(h);
    }

    int dead_count(PendingCounts::Handle h) { return counts.dead_count(h); }
    void increment_dead_count(PendingCounts::Handle h) {
      counts.increment_dead_count(h);
    }
    PendingCounts::AdjustResult adjust_for_activation(PendingCounts::Handle h,
                                                      bool increment_dead) {
      return counts.adjust_for_activation(h, increment_dead);
    }
    PendingCounts::AdjustResult adjust_for_activation_atomic(
        PendingCounts::Handle h, bool increment_dead) {
      return counts.adjust_for_activation_atomic(h, increment_dead);
    }

    ~IterationState() { delete[] input_tensors; }

   private:
    PendingCounts counts;
  };
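
  // Sketch of the intended use of the pending counts above (illustrative; the
  // authoritative update logic lives in the `ActivateNodes*` methods of
  // `FrameState` below):
  //
  //   if (iter_state->decrement_pending(h, 1) == 0) {
  //     // The last input has arrived; the node behind `h` can be scheduled.
  //   }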

  struct FrameState {
    explicit FrameState(const ImmutableExecutorState& immutable_state,
                        int parallel_iters)
        : immutable_state(immutable_state),
          max_parallel_iterations(parallel_iters),
          num_outstanding_iterations(1),
          iterations(parallel_iters + 1),
          iterations_raw(iterations.data()) {}

    // A new frame is created for each loop. Execution starts at iteration 0.
    // When a value at iteration 0 passes through a NextIteration node,
    // iteration 1 is created and starts running. Note that iteration 0 may
    // still be running so multiple iterations may run in parallel. The
    // frame maintains the state of iterations in several data structures
    // such as pending_count and input_tensors. When iteration 0 completes,
    // we garbage collect the state of iteration 0.
    //
    // A frame instance is considered "done" and can be garbage collected
    // if all its inputs have entered and all its iterations are "done".
    //
    // A frame manages the live iterations of an iterative computation.
    // Iteration i is considered "done" when there are no outstanding ops,
    // frames at iteration i are done, all recvs for this iteration are
    // completed, and iteration i-1 is done. For iteration 0, we instead
    // wait for there to be no more pending inputs of the frame.
    //
    // Frames and iterations are garbage collected once they are done.
    // The state we need to keep around is highly dependent on the
    // parallelism enabled by the scheduler. We may want to have the
    // scheduler dynamically control the outstanding number of live
    // parallel frames and iterations. To reduce the state space, the
    // scheduler might want to schedule ops in inner frames first and
    // lower iterations first.
    //
    // This frame state is mostly initialized lazily on demand so we
    // don't introduce unnecessary overhead.

    // The immutable state of the executor the frame is in.
    const ImmutableExecutorState& immutable_state;

    // The name of this frame, which is the concatenation of its parent
    // frame name, the iteration of the parent frame when this frame was
    // created, and the value of the attr 'frame_name'.
    string frame_name;

    // The unique id for this frame. Generated by fingerprinting
    // frame_name.
    uint64 frame_id;

    // The iteration state of its parent frame when this frame is created.
    // nullptr if there is no parent frame. The frame_name/parent_iter pair
    // uniquely identifies this FrameState.
    IterationState* parent_iter = nullptr;

    // The FrameState of its parent frame.
    FrameState* parent_frame = nullptr;

    // The maximum allowed number of parallel iterations.
    const int max_parallel_iterations;

    // The number of inputs this frame is still waiting for.
    int num_pending_inputs = 0;

    // The highest iteration number we have reached so far in this frame.
    int64 iteration_count TF_GUARDED_BY(mu) = 0;

    // The number of outstanding iterations.
    int num_outstanding_iterations TF_GUARDED_BY(mu) = 1;

   private:
    // The active iteration states of this frame.
    gtl::InlinedVector<IterationState*, 12> iterations;
    IterationState** const iterations_raw TF_GUARDED_BY(mu);
    IterationState* iterations_first TF_GUARDED_BY(mu);

   public:
    // The NextIteration nodes to enter a new iteration. If the number of
    // outstanding iterations reaches the limit, we will defer the start of
    // the next iteration until the number of outstanding iterations falls
    // below the limit.
    std::vector<std::pair<const NodeItem*, Entry>> next_iter_roots
        TF_GUARDED_BY(mu);

    // The values of the loop invariants for this loop. They are added into
    // this list as they "enter" the frame. When a loop invariant enters,
    // we make it available to all active iterations. When the frame starts
    // a new iteration, we make all the current loop invariants available
    // to the new iteration.
    std::vector<std::pair<const NodeItem*, Entry>> inv_values TF_GUARDED_BY(mu);

    // The list of dead exit node items for the current highest iteration. We
    // will only "execute" the dead exits of the final iteration.
    std::vector<const NodeItem*> dead_exits TF_GUARDED_BY(mu);

    // Static information specific to this frame.
    PendingCounts* pending_counts = nullptr;
    int total_input_tensors = 0;
    std::vector<const NodeItem*>* nodes = nullptr;

    // Lock ordering: ExecutorState.mu_ < mu;
    // during structured traversal: parent_frame->mu < mu.
    mutex mu;

    void InitializeFrameInfo(const ImmutableExecutorState::FrameInfo& finfo);

    inline IterationState* GetIteration(int64 iter)
        TF_SHARED_LOCKS_REQUIRED(mu) {
      if (TF_PREDICT_TRUE(iter == 0)) {
        return iterations_first;
      } else {
        size_t index = iter % (max_parallel_iterations + 1);
        return iterations_raw[index];
      }
    }
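
    // Note on `GetIteration()` above: `iterations` behaves as a ring buffer of
    // `max_parallel_iterations + 1` slots, so with e.g. 2 parallel iterations,
    // iterations 1, 2, 3, 4 map to slots 1, 2, 0, 1; iteration 0 is served from
    // `iterations_first` to keep the common single-iteration case cheap.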

    void SetIteration(int64 iter, IterationState* state);

    // Adjust the outstanding op count by 'delta' and clean up the iterations in
    // the frame if no more ops are outstanding. Returns true iff the execution
    // of the frame is done.
    //
    // Avoids acquiring the lock in the common case that the frame is not done.
    bool AdjustOutstandingOps(IterationState* iter_state, int delta,
                              TaggedNodeSeq* ready);

    bool AdjustOutstandingOpsLocked(IterationState* iter_state, int delta,
                                    TaggedNodeSeq* ready)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu);

    bool AdjustOutstandingOpsFastPath(IterationState* iter_state, int delta)
        TF_SHARED_LOCKS_REQUIRED(mu);

    // Convenience methods for the above 'Adjust' calls where delta takes the
    // common value of -1.
    bool DecrementOutstandingOps(IterationState* iter_state,
                                 TaggedNodeSeq* ready);

    bool DecrementOutstandingOpsLocked(IterationState* iter_state,
                                       TaggedNodeSeq* ready);

    // Returns true if the computation in the frame is completed.
    bool IsFrameDone();

    // Returns true if the iteration of the frame is completed.
    bool IsIterationDone(IterationState* iter_state)
        TF_SHARED_LOCKS_REQUIRED(mu);

    // Increments the iteration id. If this is a new iteration, initialize it.
    //
    // Returns a pointer to the new iteration.
    IterationState* IncrementIteration(TaggedNodeSeq* ready)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu);

    // Activate all the deferred NextIteration nodes in a new iteration.
    void ActivateNexts(IterationState* iter_state, TaggedNodeSeq* ready)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu);

    // Activate all the current loop invariants in a new iteration.
    void ActivateLoopInvs(IterationState* iter_state, TaggedNodeSeq* ready)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu);

    // Add a new loop invariant and make it available to all active
    // iterations.
    void AddLoopInv(const NodeItem* item, const Entry& entry,
                    TaggedNodeSeq* ready) TF_EXCLUSIVE_LOCKS_REQUIRED(mu);

    // Activate the successors of a node. Contents of *outputs are left in an
    // indeterminate state after returning from this method.
    //
    // In the case that 'item' is a simple node (no merge/control outputs) this
    // will acquire a shared lock and can run concurrently with other
    // invocations.
    //
    // Return true if the frame is done after activation.
    bool ActivateNodesAndAdjustOutstanding(const NodeItem* item,
                                           const bool is_dead,
                                           IterationState* iter_state,
                                           EntryVector* outputs,
                                           TaggedNodeSeq* ready);

    // Same as the above, but requires 'mu' already held in exclusive mode.
    int ActivateNodesLocked(const NodeItem* item, const bool is_dead,
                            IterationState* iter_state, EntryVector* outputs,
                            TaggedNodeSeq* ready)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu);

    // Cleanup iterations of this frame starting from the given iteration.
    bool CleanupIterations(IterationState* iter_state, TaggedNodeSeq* ready)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu);

    void DumpIterationState(PropagatorState* parent) {
      mutex_lock l(mu);
      for (IterationState* iteration : iterations) {
        if (iteration) {
          LOG(WARNING) << "  Iteration:";
          parent->DumpIterationState(this, iteration);
        }
      }
    }

    ~FrameState() {
      for (size_t i = 0; i < iterations.size(); ++i) {
        delete iterations[i];
        iterations[i] = nullptr;
      }
    }

   private:
    // REQUIRES: `!item->is_any_consumer_merge_or_control_trigger`.
    // This variant does not use atomic operations to modify the pending counts
    // and thus must hold the exclusive lock.
    int ActivateNodesFastPathLocked(const NodeItem* item, const bool is_dead,
                                    IterationState* iter_state,
                                    EntryVector* outputs, TaggedNodeSeq* ready)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
      return ActivateNodesFastPathInternal<false>(item, is_dead, iter_state,
                                                  outputs, ready);
    }

    // REQUIRES: `!item->is_any_consumer_merge_or_control_trigger`.
    // This variant uses atomic operations to modify the pending counts.
    int ActivateNodesFastPathShared(const NodeItem* item, const bool is_dead,
                                    IterationState* iter_state,
                                    EntryVector* outputs, TaggedNodeSeq* ready)
        TF_SHARED_LOCKS_REQUIRED(mu) {
      return ActivateNodesFastPathInternal<true>(item, is_dead, iter_state,
                                                 outputs, ready);
    }

    template <bool atomic>
    int ActivateNodesFastPathInternal(const NodeItem* item, const bool is_dead,
                                      IterationState* iter_state,
                                      EntryVector* outputs,
                                      TaggedNodeSeq* ready);

    int ActivateNodesSlowPath(const NodeItem* item, const bool is_dead,
                              IterationState* iter_state, EntryVector* outputs,
                              TaggedNodeSeq* ready)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
  };

 public:
  // Creates and adds a `TaggedNode` for each node in `roots` to `*ready`.
  void ActivateRoots(gtl::ArraySlice<const NodeItem*> roots,
                     TaggedNodeSeq* ready);

  // After processing the outputs, propagates the outputs to their dsts.
  // Contents of *outputs are left in an indeterminate state after
  // returning from this method.
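  //
  // A sketch of a typical call (illustrative): the caller must not rely on the
  // contents of `outputs` once this returns.
  //
  //   EntryVector outputs;
  //   // ... fill `outputs` from the kernel's results ...
  //   propagator->PropagateOutputs(tagged_node, &outputs, &ready);
  //   // `outputs` is now in an indeterminate state and should be discarded.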
  void PropagateOutputs(const TaggedNode& tagged_node, EntryVector* outputs,
                        TaggedNodeSeq* ready);

  // Returns an array of `Entry` objects corresponding to the inputs of
  // `tagged_node`.
  //
  // NOTE: Thread safety analysis is disabled on this method, because the
  // underlying `IterationState` and its array of `input_tensors` retain the
  // same address while the iteration is live.
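  //
  // For example, if `tagged_node.node_item->input_start == 7`, then the node's
  // j-th input for this iteration lives at
  // `tagged_node.input_iter->input_tensors[7 + j]`.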
  Entry* GetInputTensors(const TaggedNode& tagged_node) const
      TF_NO_THREAD_SAFETY_ANALYSIS {
    return tagged_node.input_iter->input_tensors +
           tagged_node.node_item->input_start;
  }

  FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const {
    return {tagged_node.input_frame->frame_id,
            tagged_node.input_iter->iter_num};
  }

  // Provide debugging output of the state of the executor.
  void DumpState();

  // For debugging/logging only.
  void MaybeMarkStarted(const TaggedNode& tagged_node) {
    // TODO(misard) Replace with a finer-grain enabling flag once we add better
    // optional debugging support.
    if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) {
      mutex_lock l(tagged_node.input_frame->mu);
      tagged_node.input_iter->mark_started(
          immutable_state_.pending_ids()[tagged_node.node_item->node_id]);
    }
  }

  void MaybeMarkCompleted(const TaggedNode& tagged_node) {
    // TODO(misard) Replace with a finer-grain enabling flag once we add better
    // optional debugging support.
    if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) {
      mutex_lock l(tagged_node.input_frame->mu);
      tagged_node.input_iter->mark_completed(
          immutable_state_.pending_ids()[tagged_node.node_item->node_id]);
    }
  }

 private:
  // Find an existing or create a new child frame in the frame 'frame' at
  // iteration 'iter'.
  void FindOrCreateChildFrame(FrameState* frame, IterationState* iter_state,
                              const NodeItem& node_item, FrameState** child);

  // Delete a frame. Called when the frame is done.
  void DeleteFrame(FrameState* frame, TaggedNodeSeq* ready);

  // Cleanup frames and iterations starting from frame/iter. Called when
  // a child frame is done.
  void CleanupFramesIterations(FrameState* frame, IterationState* iter_state,
                               TaggedNodeSeq* ready);

  // Provide debugging output about an outstanding iteration in the executor.
  void DumpIterationState(const FrameState* frame, IterationState* iteration);

  const ImmutableExecutorState& immutable_state_;
  const int64 step_id_;
  const bool vlog_;

  mutex mu_;

  // The root frame in which the execution of this step is started.
  FrameState* root_frame_;

  // Mapping from frame ID to outstanding frames. A new frame is created
  // at some iteration of an active frame. So the unique key for the new
  // child frame is a hash composed of the ID of the parent frame, the iteration
  // number at which the parent frame is creating the new frame, and the
  // name of the new frame from nodedef.
  absl::flat_hash_map<uint64, FrameState*> outstanding_frames_
      TF_GUARDED_BY(mu_);

  TF_DISALLOW_COPY_AND_ASSIGN(PropagatorState);
};

inline int64 PropagatorState::TaggedNode::get_iter_num() const {
  return input_iter->iter_num;
}

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_STATE_H_