• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PENDING_COUNTS_H_
2 #define TENSORFLOW_CORE_COMMON_RUNTIME_PENDING_COUNTS_H_
3 
4 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
5 
6 Licensed under the Apache License, Version 2.0 (the "License");
7 you may not use this file except in compliance with the License.
8 You may obtain a copy of the License at
9 
10     http://www.apache.org/licenses/LICENSE-2.0
11 
12 Unless required by applicable law or agreed to in writing, software
13 distributed under the License is distributed on an "AS IS" BASIS,
14 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 See the License for the specific language governing permissions and
16 limitations under the License.
17 ==============================================================================*/
18 
19 #include "tensorflow/core/lib/gtl/flatmap.h"
20 #include "tensorflow/core/lib/hash/hash.h"
21 #include "tensorflow/core/platform/logging.h"
22 #include "tensorflow/core/platform/macros.h"
23 #include "tensorflow/core/util/port.h"
24 
25 namespace tensorflow {
26 
27 // PendingCounts is an internal helper class to keep track of pending and
28 // dead counts for nodes, for use in the ExecutorState module.  It
29 // holds a map from Handles to various counts for that handle.  This
30 // information is needed per frame iteration. The amount of memory
31 // needed for an iteration is the same across all executions of the
32 // iteration. The memory amount and handles are precomputed at startup
33 // using a Layout object.
34 //
35 //    PendingCounts::Layout layout;
36 //    std::vector<PendingCounts::Handle> h(C);
37 //    for (int id = 0; id < C; id++) {
38 //      h[id] = r.AddHandle(max_pending[id], max_dead[id]);
39 //    }
40 //
41 // When we actually want to start an iteration we first create a
42 // PendingCounts object and then index into it using the precomputed
43 // handles:
44 
45 //    PendingCounts counts(layout);
46 //    ...
47 //    counts.decrement_pending(h[id], 1);
48 class PendingCounts {
49  public:
50   // The state machine for a node's execution.
51   enum NodeState {
52     // The pending count for the node > 0.
53     PENDING_NOTREADY,
54     // The pending count for the node == 0, but the node has not
55     // started executing.
56     PENDING_READY,
57     // The node has started executing.
58     STARTED,
59     // The node has finished executing.
60     COMPLETED
61   };
62 
63   // An opaque handle indicating where in the PendingCounts data structure
64   // the appropriate count information can be found.
65   class Handle;
66   // Given a node that needs to represent counts no larger than the
67   // specified "max_pending_count" and "max_dead_count", create a
68   // handle that can be passed to various PendingCounts routines
69   // to retrieve the count data for this node.
70   class Layout {
71    public:
72     Handle CreateHandle(size_t max_pending_count, size_t max_dead_count);
73 
74    private:
75     friend class PendingCounts;
76     int next_offset_ = 0;  // Next byte offset to allocate
77   };
78 
79   // Create a new PendingCounts object that can hold the state of
80   // all the Handles allocated from "final_allocator".
PendingCounts(Layout layout)81   explicit PendingCounts(Layout layout)
82       : num_bytes_(layout.next_offset_), bytes_(new char[num_bytes_]) {}
83 
84   // Create a new PendingCounts object with the same layout and counts
85   // as "other".
PendingCounts(const PendingCounts & other)86   explicit PendingCounts(const PendingCounts& other)
87       : num_bytes_(other.num_bytes_), bytes_(new char[num_bytes_]) {
88     CHECK_EQ(uintptr_t(bytes_) % alignof(LargeCounts), 0);
89     memcpy(bytes_, other.bytes_, other.num_bytes_);
90   }
91 
~PendingCounts()92   ~PendingCounts() { delete[] bytes_; }
93 
set_initial_count(Handle h,size_t pending_count)94   void set_initial_count(Handle h, size_t pending_count) {
95     if (h.is_large_) {
96       LargeCounts* c = Large(h);
97       c->pending = pending_count;
98       c->dead_count = 0;
99       c->has_started = 0;
100     } else {
101       PackedCounts* c = Packed(h);
102       DCHECK_LE(pending_count, kMaxCountForPackedCounts);
103       c->pending = pending_count;
104       c->dead_count = 0;
105       c->has_started = 0;
106     }
107   }
108 
node_state(Handle h)109   NodeState node_state(Handle h) {
110     if (h.is_large_) {
111       return NodeStateForStruct(Large(h));
112     } else {
113       return NodeStateForStruct(Packed(h));
114     }
115   }
mark_started(Handle h)116   void mark_started(Handle h) {
117     DCHECK_EQ(pending(h), 0);
118     if (h.is_large_) {
119       LargeCounts* c = Large(h);
120       DCHECK_EQ(c->has_started, 0);
121       c->has_started = 1;
122     } else {
123       PackedCounts* c = Packed(h);
124       DCHECK_EQ(c->has_started, 0);
125       c->has_started = 1;
126     }
127   }
mark_completed(Handle h)128   void mark_completed(Handle h) {
129     if (h.is_large_) {
130       LargeCounts* c = Large(h);
131       DCHECK_EQ(c->has_started, 1);
132       c->pending = 1;
133     } else {
134       PackedCounts* c = Packed(h);
135       DCHECK_EQ(c->has_started, 1);
136       c->pending = 1;
137     }
138   }
pending(Handle h)139   int pending(Handle h) {
140     if (h.is_large_) {
141       LargeCounts* c = Large(h);
142       if (PENDING_NOTREADY == NodeStateForStruct(c)) {
143         return c->pending;
144       } else {
145         // The pending count encodes the state once the node has
146         // started, so just return 0.
147         return 0;
148       }
149     } else {
150       PackedCounts* c = Packed(h);
151       if (PENDING_NOTREADY == NodeStateForStruct(c)) {
152         return c->pending;
153       } else {
154         // The pending count encodes the state once the node has
155         // started, so just return 0.
156         return 0;
157       }
158     }
159   }
decrement_pending(Handle h,int v)160   int decrement_pending(Handle h, int v) {
161     DCHECK_GE(pending(h), v);
162     if (h.is_large_) {
163       LargeCounts* c = Large(h);
164       c->pending -= v;
165       return c->pending;
166     } else {
167       PackedCounts* c = Packed(h);
168       c->pending -= v;
169       return c->pending;
170     }
171   }
172   // Mark a merge node as live
173   // REQUIRES: Node corresponding to "h" is a merge node
mark_live(Handle h)174   void mark_live(Handle h) {
175     if (h.is_large_) {
176       LargeCounts* c = Large(h);
177       // Only do anything if the node hasn't already started executing.
178       if (PENDING_NOTREADY == NodeStateForStruct(c)) {
179         c->pending &= ~static_cast<int>(0x1);
180       }
181     } else {
182       PackedCounts* c = Packed(h);
183       // Only do anything if the node hasn't already started executing.
184       if (PENDING_NOTREADY == NodeStateForStruct(c)) {
185         static_assert(7 == kMaxCountForPackedCounts,
186                       "Live flag incorrect for max packed count");
187         c->pending &= 0x6;
188       }
189     }
190   }
191 
dead_count(Handle h)192   int dead_count(Handle h) {
193     int r = h.is_large_ ? Large(h)->dead_count : Packed(h)->dead_count;
194     return r;
195   }
increment_dead_count(Handle h)196   void increment_dead_count(Handle h) {
197     if (h.is_large_) {
198       LargeCounts* c = Large(h);
199       if (PENDING_NOTREADY == NodeStateForStruct(c)) {
200         c->dead_count++;
201       }
202     } else {
203       PackedCounts* c = Packed(h);
204       if (PENDING_NOTREADY == NodeStateForStruct(c)) {
205         DCHECK_LT(c->dead_count, kMaxCountForPackedCounts);
206         c->dead_count++;
207       }
208     }
209   }
210 
211   // A streamlined routine that does several pieces of bookkeeping at
212   // once.  Equivalent to:
213   //    if (increment_dead) increment_dead_count(h);
214   //    decrement_pending(h, 1);
215   //    *pending_result = pending(h);
216   //    *dead_result = dead_count(h);
adjust_for_activation(Handle h,bool increment_dead,int * pending_result,int * dead_result)217   void adjust_for_activation(Handle h, bool increment_dead, int* pending_result,
218                              int* dead_result) {
219     DCHECK_GE(pending(h), 1);
220     if (h.is_large_) {
221       adjust_for_activation_shared(Large(h), increment_dead, pending_result,
222                                    dead_result);
223     } else {
224       adjust_for_activation_shared(Packed(h), increment_dead, pending_result,
225                                    dead_result);
226     }
227   }
228 
229   class Handle {
230    public:
Handle()231     Handle() : byte_offset_(0), is_large_(0) {}
232 
233    private:
234     friend class PendingCounts;
235     int byte_offset_ : 31;  // Byte offset of the rep in PendingCounts object
236     bool is_large_ : 1;  // If true, rep is LargeCounts; otherwise PackedCounts
237   };
238 
239  private:
240   template <typename T>
adjust_for_activation_shared(T * c,bool increment_dead,int * pending_result,int * dead_result)241   inline void adjust_for_activation_shared(T* c, bool increment_dead,
242                                            int* pending_result,
243                                            int* dead_result) {
244     if (increment_dead) {
245       if (PENDING_NOTREADY == NodeStateForStruct(c)) {
246         c->dead_count++;
247       }
248     }
249     c->pending -= 1;
250     *dead_result = c->dead_count;
251     *pending_result = c->pending;
252   }
253 
254   // We keep track of the pending count and dead input count for each
255   // graph node.  The representation used here is designed to be cache
256   // efficient for graphs with large numbers of nodes, where most
257   // nodes have relatively small maximum pending counts (e.g. for one
258   // LSTM model, 99% of 5000+ nodes had in-degrees of 3 or less).  We
259   // use one byte to hold both the pending and dead count for a node
260   // where these together can fit in one byte, and we use a hash table
261   // to handle the rare node ids that need larger counts than this.
262   // Each frame in this subgraph has its own PendingCounts.
263 
264   // We use 3 bits each for dead_count and pending.
265   static const int kMaxCountForPackedCounts = 7;
266 
267   // Most counts are small, so we pack a pending count and a dead
268   // count into 3 bits each, use 1 bit to indicate that the node has
269   // started computing.
270   struct PackedCounts {
271     uint8 pending : 3;
272     uint8 dead_count : 3;
273     uint8 has_started : 1;
274   };
275 
276   struct LargeCounts {
277     uint32 pending;
278     uint32 dead_count : 31;
279     uint8 has_started : 1;
280   };
281 
282   template <typename T>
NodeStateForStruct(T * c)283   NodeState NodeStateForStruct(T* c) const {
284     if (c->has_started) {
285       return (c->pending == 0) ? STARTED : COMPLETED;
286     } else {
287       return (c->pending == 0) ? PENDING_READY : PENDING_NOTREADY;
288     }
289   }
Large(Handle h)290   inline LargeCounts* Large(Handle h) {
291     DCHECK(h.is_large_);
292     DCHECK_LE(h.byte_offset_ + sizeof(LargeCounts), num_bytes_);
293     DCHECK_EQ(h.byte_offset_ % alignof(LargeCounts), 0);
294     return reinterpret_cast<LargeCounts*>(bytes_ + h.byte_offset_);
295   }
Packed(Handle h)296   inline PackedCounts* Packed(Handle h) {
297     DCHECK(!h.is_large_);
298     DCHECK_LE(h.byte_offset_ + sizeof(PackedCounts), num_bytes_);
299     return reinterpret_cast<PackedCounts*>(bytes_ + h.byte_offset_);
300   }
301 
302   const int num_bytes_;  // Just for bounds checking in debug mode
303   char* bytes_;          // Array of num_bytes_ bytes
304 
305   void operator=(const PendingCounts&) = delete;
306 };
307 
CreateHandle(size_t max_pending_count,size_t max_dead_count)308 inline PendingCounts::Handle PendingCounts::Layout::CreateHandle(
309     size_t max_pending_count, size_t max_dead_count) {
310   Handle result;
311   if ((max_pending_count > kMaxCountForPackedCounts) ||
312       (max_dead_count > kMaxCountForPackedCounts)) {
313     int B = sizeof(LargeCounts);
314     // Round byte offset to proper alignment
315     DCHECK_GE(sizeof(LargeCounts), alignof(LargeCounts));
316     int64 offset = ((static_cast<int64>(next_offset_) + B - 1) / B) * B;
317     result.byte_offset_ = offset;
318     result.is_large_ = true;
319     next_offset_ = result.byte_offset_ + B;
320   } else {
321     result.byte_offset_ = next_offset_;
322     result.is_large_ = false;
323     DCHECK_EQ(sizeof(PackedCounts), 1);
324     next_offset_ += sizeof(PackedCounts);
325   }
326   return result;
327 }
328 
329 }  // end namespace tensorflow
330 
331 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_PENDING_COUNTS_H_
332