#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PENDING_COUNTS_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_PENDING_COUNTS_H_

/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/port.h"

namespace tensorflow {

// PendingCounts is an internal helper class to keep track of pending and
// dead counts for nodes, for use in the ExecutorState module.  It
// holds a map from Handles to various counts for that handle.  This
// information is needed per frame iteration.  The amount of memory
// needed for an iteration is the same across all executions of the
// iteration.  The memory amount and handles are precomputed at startup
// using a Layout object.
//
//    PendingCounts::Layout layout;
//    std::vector<PendingCounts::Handle> h(C);
//    for (int id = 0; id < C; id++) {
//      h[id] = layout.CreateHandle(max_pending[id], max_dead[id]);
//    }
//
// When we actually want to start an iteration we first create a
// PendingCounts object and then index into it using the precomputed
// handles:
//
//    PendingCounts counts(layout);
//    ...
//    counts.decrement_pending(h[id], 1);
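//
// A fuller sketch of one node's lifecycle (illustrative only, reusing
// "counts" and "h" from the snippets above):
//
//    counts.set_initial_count(h[id], max_pending[id]);
//    ...
//    if (counts.decrement_pending(h[id], 1) == 0) {
//      counts.mark_started(h[id]);    // PENDING_READY -> STARTED
//      ...                            // run the node's kernel
//      counts.mark_completed(h[id]);  // STARTED -> COMPLETED
//    }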
class PendingCounts {
 public:
  // The state machine for a node's execution.
  enum NodeState {
    // The pending count for the node > 0.
    PENDING_NOTREADY,
    // The pending count for the node == 0, but the node has not
    // started executing.
    PENDING_READY,
    // The node has started executing.
    STARTED,
    // The node has finished executing.
    COMPLETED
  };
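
  // Typical transitions, as enforced by the DCHECKs in the methods
  // below:
  //    PENDING_NOTREADY -> PENDING_READY  (decrement_pending reaches 0)
  //    PENDING_READY    -> STARTED        (mark_started)
  //    STARTED          -> COMPLETED      (mark_completed)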

  // An opaque handle indicating where in the PendingCounts data structure
  // the appropriate count information can be found.
  class Handle;
  // Given a node that needs to represent counts no larger than the
  // specified "max_pending_count" and "max_dead_count", create a
  // handle that can be passed to various PendingCounts routines
  // to retrieve the count data for this node.
  class Layout {
   public:
    Handle CreateHandle(size_t max_pending_count, size_t max_dead_count);

   private:
    friend class PendingCounts;
    int next_offset_ = 0;  // Next byte offset to allocate
  };

  // Create a new PendingCounts object that can hold the state of
  // all the Handles allocated from "layout".
  explicit PendingCounts(Layout layout)
      : num_bytes_(layout.next_offset_), bytes_(new char[num_bytes_]) {}

  // Create a new PendingCounts object with the same layout and counts
  // as "other".
  explicit PendingCounts(const PendingCounts& other)
      : num_bytes_(other.num_bytes_), bytes_(new char[num_bytes_]) {
    CHECK_EQ(uintptr_t(bytes_) % alignof(LargeCounts), 0);
    memcpy(bytes_, other.bytes_, other.num_bytes_);
  }

  ~PendingCounts() { delete[] bytes_; }

  void set_initial_count(Handle h, size_t pending_count) {
    if (h.is_large_) {
      LargeCounts* c = Large(h);
      c->pending = pending_count;
      c->dead_count = 0;
      c->has_started = 0;
    } else {
      PackedCounts* c = Packed(h);
      DCHECK_LE(pending_count, kMaxCountForPackedCounts);
      c->pending = pending_count;
      c->dead_count = 0;
      c->has_started = 0;
    }
  }

  NodeState node_state(Handle h) {
    if (h.is_large_) {
      return NodeStateForStruct(Large(h));
    } else {
      return NodeStateForStruct(Packed(h));
    }
  }
  void mark_started(Handle h) {
    DCHECK_EQ(pending(h), 0);
    if (h.is_large_) {
      LargeCounts* c = Large(h);
      DCHECK_EQ(c->has_started, 0);
      c->has_started = 1;
    } else {
      PackedCounts* c = Packed(h);
      DCHECK_EQ(c->has_started, 0);
      c->has_started = 1;
    }
  }
  void mark_completed(Handle h) {
    if (h.is_large_) {
      LargeCounts* c = Large(h);
      DCHECK_EQ(c->has_started, 1);
      c->pending = 1;
    } else {
      PackedCounts* c = Packed(h);
      DCHECK_EQ(c->has_started, 1);
      c->pending = 1;
    }
  }
  int pending(Handle h) {
    if (h.is_large_) {
      LargeCounts* c = Large(h);
      if (PENDING_NOTREADY == NodeStateForStruct(c)) {
        return c->pending;
      } else {
        // The pending count encodes the state once the node has
        // started, so just return 0.
        return 0;
      }
    } else {
      PackedCounts* c = Packed(h);
      if (PENDING_NOTREADY == NodeStateForStruct(c)) {
        return c->pending;
      } else {
        // The pending count encodes the state once the node has
        // started, so just return 0.
        return 0;
      }
    }
  }
  int decrement_pending(Handle h, int v) {
    DCHECK_GE(pending(h), v);
    if (h.is_large_) {
      LargeCounts* c = Large(h);
      c->pending -= v;
      return c->pending;
    } else {
      PackedCounts* c = Packed(h);
      c->pending -= v;
      return c->pending;
    }
  }
  // Mark a merge node as live.
  // REQUIRES: Node corresponding to "h" is a merge node.
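  // (The executor is assumed to encode a merge node's initial pending
  // count so that the low bit means "no live data input has arrived
  // yet"; clearing that bit below is what marks the node live, without
  // disturbing the rest of the count.)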
  void mark_live(Handle h) {
    if (h.is_large_) {
      LargeCounts* c = Large(h);
      // Only do anything if the node hasn't already started executing.
      if (PENDING_NOTREADY == NodeStateForStruct(c)) {
        c->pending &= ~static_cast<int>(0x1);
      }
    } else {
      PackedCounts* c = Packed(h);
      // Only do anything if the node hasn't already started executing.
      if (PENDING_NOTREADY == NodeStateForStruct(c)) {
        static_assert(7 == kMaxCountForPackedCounts,
                      "Live flag incorrect for max packed count");
        c->pending &= 0x6;
      }
    }
  }

  int dead_count(Handle h) {
    int r = h.is_large_ ? Large(h)->dead_count : Packed(h)->dead_count;
    return r;
  }
  void increment_dead_count(Handle h) {
    if (h.is_large_) {
      LargeCounts* c = Large(h);
      if (PENDING_NOTREADY == NodeStateForStruct(c)) {
        c->dead_count++;
      }
    } else {
      PackedCounts* c = Packed(h);
      if (PENDING_NOTREADY == NodeStateForStruct(c)) {
        DCHECK_LT(c->dead_count, kMaxCountForPackedCounts);
        c->dead_count++;
      }
    }
  }

  // A streamlined routine that does several pieces of bookkeeping at
  // once.  Equivalent to:
  //    if (increment_dead) increment_dead_count(h);
  //    decrement_pending(h, 1);
  //    *pending_result = pending(h);
  //    *dead_result = dead_count(h);
  void adjust_for_activation(Handle h, bool increment_dead,
                             int* pending_result, int* dead_result) {
    DCHECK_GE(pending(h), 1);
    if (h.is_large_) {
      adjust_for_activation_shared(Large(h), increment_dead, pending_result,
                                   dead_result);
    } else {
      adjust_for_activation_shared(Packed(h), increment_dead, pending_result,
                                   dead_result);
    }
  }
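
  // For illustration, an executor might propagate one input edge to a
  // node roughly like this (names here are hypothetical):
  //
  //    int pending, dead;
  //    counts.adjust_for_activation(h, input_is_dead, &pending, &dead);
  //    if (pending == 0) {
  //      // All inputs have arrived, so the node can be scheduled
  //      // (as dead if "dead" is nonzero, depending on node semantics).
  //    }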

  class Handle {
   public:
    Handle() : byte_offset_(0), is_large_(0) {}

   private:
    friend class PendingCounts;
    int byte_offset_ : 31;  // Byte offset of the rep in PendingCounts object
    bool is_large_ : 1;  // If true, rep is LargeCounts; otherwise PackedCounts
  };

 private:
  template <typename T>
  inline void adjust_for_activation_shared(T* c, bool increment_dead,
                                           int* pending_result,
                                           int* dead_result) {
    if (increment_dead) {
      if (PENDING_NOTREADY == NodeStateForStruct(c)) {
        c->dead_count++;
      }
    }
    c->pending -= 1;
    *dead_result = c->dead_count;
    *pending_result = c->pending;
  }

  // We keep track of the pending count and dead input count for each
  // graph node.  The representation used here is designed to be cache
  // efficient for graphs with large numbers of nodes, where most
  // nodes have relatively small maximum pending counts (e.g. for one
  // LSTM model, 99% of 5000+ nodes had in-degrees of 3 or less).  We
  // use one byte to hold both the pending and dead count for a node
  // where these together can fit in one byte, and we store a larger,
  // aligned LargeCounts struct inline in the same byte array for the
  // rare nodes that need bigger counts than this.  Each frame
  // iteration has its own PendingCounts.

  // We use 3 bits each for dead_count and pending.
  static const int kMaxCountForPackedCounts = 7;

  // Most counts are small, so we pack a pending count and a dead
  // count into 3 bits each, and use 1 bit to indicate whether the
  // node has started computing.
  struct PackedCounts {
    uint8 pending : 3;
    uint8 dead_count : 3;
    uint8 has_started : 1;
  };
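
  // For example, a node with a pending count of 3 that has not started
  // and has no dead inputs fits in a single byte holding
  // {pending: 3, dead_count: 0, has_started: 0}.  (The bit layout
  // within the byte is implementation-defined; only the field values
  // matter.)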

  struct LargeCounts {
    uint32 pending;
    uint32 dead_count : 31;
    uint8 has_started : 1;
  };

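  // State encoding used by NodeStateForStruct() below:
  //    has_started == 0, pending >  0   =>  PENDING_NOTREADY
  //    has_started == 0, pending == 0   =>  PENDING_READY
  //    has_started == 1, pending == 0   =>  STARTED
  //    has_started == 1, pending != 0   =>  COMPLETED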
  template <typename T>
  NodeState NodeStateForStruct(T* c) const {
    if (c->has_started) {
      return (c->pending == 0) ? STARTED : COMPLETED;
    } else {
      return (c->pending == 0) ? PENDING_READY : PENDING_NOTREADY;
    }
  }
  inline LargeCounts* Large(Handle h) {
    DCHECK(h.is_large_);
    DCHECK_LE(h.byte_offset_ + sizeof(LargeCounts), num_bytes_);
    DCHECK_EQ(h.byte_offset_ % alignof(LargeCounts), 0);
    return reinterpret_cast<LargeCounts*>(bytes_ + h.byte_offset_);
  }
  inline PackedCounts* Packed(Handle h) {
    DCHECK(!h.is_large_);
    DCHECK_LE(h.byte_offset_ + sizeof(PackedCounts), num_bytes_);
    return reinterpret_cast<PackedCounts*>(bytes_ + h.byte_offset_);
  }

  const int num_bytes_;  // Just for bounds checking in debug mode
  char* bytes_;          // Array of num_bytes_ bytes

  void operator=(const PendingCounts&) = delete;
};

inline PendingCounts::Handle PendingCounts::Layout::CreateHandle(
    size_t max_pending_count, size_t max_dead_count) {
  Handle result;
  if ((max_pending_count > kMaxCountForPackedCounts) ||
      (max_dead_count > kMaxCountForPackedCounts)) {
    int B = sizeof(LargeCounts);
    // Round byte offset to proper alignment
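    // (For example, assuming sizeof(LargeCounts) == 12 on a typical
    // ABI, a next_offset_ of 5 rounds up to 12, while a next_offset_
    // of 24 stays at 24.)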
    DCHECK_GE(sizeof(LargeCounts), alignof(LargeCounts));
    int64 offset = ((static_cast<int64>(next_offset_) + B - 1) / B) * B;
    result.byte_offset_ = offset;
    result.is_large_ = true;
    next_offset_ = result.byte_offset_ + B;
  } else {
    result.byte_offset_ = next_offset_;
    result.is_large_ = false;
    DCHECK_EQ(sizeof(PackedCounts), 1);
    next_offset_ += sizeof(PackedCounts);
  }
  return result;
}

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_PENDING_COUNTS_H_