#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PENDING_COUNTS_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_PENDING_COUNTS_H_

/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <atomic>

#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/port.h"

namespace tensorflow {

// PendingCounts is an internal helper class to keep track of pending and
// dead counts for nodes, for use in the ExecutorState module.  It
// holds a map from each Handle to various counts for that handle.  This
// information is needed per frame iteration.  The amount of memory
// needed for an iteration is the same across all executions of the
// iteration.  The memory amount and handles are precomputed at startup
// using a Layout object.
//
//    PendingCounts::Layout layout;
//    std::vector<PendingCounts::Handle> h(C);
//    for (int id = 0; id < C; id++) {
//      h[id] = layout.CreateHandle(max_pending[id], max_dead[id]);
//    }
//
// When we actually want to start an iteration we first create a
// PendingCounts object and then index into it using the precomputed
// handles:
//
//    PendingCounts counts(layout);
//    ...
//    counts.decrement_pending(h[id], 1);
class PendingCounts {
 public:
  // The state machine for a node's execution.
  enum NodeState {
    // The pending count for the node > 0.
    PENDING_NOTREADY,
    // The pending count for the node == 0, but the node has not
    // started executing.
    PENDING_READY,
    // The node has started executing.
    STARTED,
    // The node has finished executing.
    COMPLETED
  };
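
  // Illustrative lifecycle sketch (an assumption for exposition, not part of
  // this header's API surface): given a Handle "h" created at Layout time and
  // a PendingCounts "counts", a node typically moves through these states as
  // follows:
  //
  //   counts.set_initial_count(h, 2);   // PENDING_NOTREADY (pending == 2)
  //   counts.decrement_pending(h, 2);   // PENDING_READY    (pending == 0)
  //   counts.mark_started(h);           // STARTED
  //   counts.mark_completed(h);         // COMPLETED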

  // An opaque handle indicating where in the PendingCounts data structure
  // the appropriate count information can be found.
  class Handle;
  // Given a node that needs to represent counts no larger than the
  // specified "max_pending_count" and "max_dead_count", create a
  // handle that can be passed to various PendingCounts routines
  // to retrieve the count data for this node.
  class Layout {
   public:
    Handle CreateHandle(size_t max_pending_count, size_t max_dead_count);

   private:
    friend class PendingCounts;
    int next_offset_ = 0;  // Next byte offset to allocate
  };

  // Create a new PendingCounts object that can hold the state of
  // all the Handles allocated from "layout".
  explicit PendingCounts(Layout layout)
      : num_bytes_(layout.next_offset_), bytes_(new char[num_bytes_]) {
    if (num_bytes_ >= sizeof(LargeCounts)) {
      CHECK_EQ(uintptr_t(bytes_) % alignof(LargeCounts), 0);
    }
  }

  // Create a new PendingCounts object with the same layout and counts
  // as "other".
  explicit PendingCounts(const PendingCounts& other)
      : num_bytes_(other.num_bytes_), bytes_(new char[num_bytes_]) {
    if (num_bytes_ >= sizeof(LargeCounts)) {
      CHECK_EQ(uintptr_t(bytes_) % alignof(LargeCounts), 0);
    }
    memcpy(bytes_, other.bytes_, other.num_bytes_);
  }

  ~PendingCounts() { delete[] bytes_; }

  void set_initial_count(Handle h, size_t pending_count) {
    if (h.is_large_) {
      std::atomic<LargeCounts>* c_ptr = Large(h);
      auto c = c_ptr->load(std::memory_order_relaxed);
      c.pending = pending_count;
      c.dead_count = 0;
      c.has_started = 0;
      c_ptr->store(c, std::memory_order_relaxed);
    } else {
      DCHECK_LE(pending_count, kMaxCountForPackedCounts);
      std::atomic<PackedCounts>* c_ptr = Packed(h);
      auto c = c_ptr->load(std::memory_order_relaxed);
      c.pending = pending_count;
      c.dead_count = 0;
      c.has_started = 0;
      c_ptr->store(c, std::memory_order_relaxed);
    }
  }

  NodeState node_state(Handle h) {
    if (h.is_large_) {
      return NodeStateForStruct(Large(h)->load(std::memory_order_relaxed));
    } else {
      return NodeStateForStruct(Packed(h)->load(std::memory_order_relaxed));
    }
  }
  void mark_started(Handle h) {
    DCHECK_EQ(pending(h), 0);
    if (h.is_large_) {
      std::atomic<LargeCounts>* c_ptr = Large(h);
      auto c = c_ptr->load(std::memory_order_relaxed);
      DCHECK_EQ(c.has_started, 0);
      c.has_started = 1;
      c_ptr->store(c, std::memory_order_relaxed);
    } else {
      std::atomic<PackedCounts>* c_ptr = Packed(h);
      auto c = c_ptr->load(std::memory_order_relaxed);
      DCHECK_EQ(c.has_started, 0);
      c.has_started = 1;
      c_ptr->store(c, std::memory_order_relaxed);
    }
  }
  void mark_completed(Handle h) {
    if (h.is_large_) {
      std::atomic<LargeCounts>* c_ptr = Large(h);
      auto c = c_ptr->load(std::memory_order_relaxed);
      DCHECK_EQ(c.has_started, 1);
      c.pending = 1;
      c_ptr->store(c, std::memory_order_relaxed);
    } else {
      std::atomic<PackedCounts>* c_ptr = Packed(h);
      auto c = c_ptr->load(std::memory_order_relaxed);
      DCHECK_EQ(c.has_started, 1);
      c.pending = 1;
      c_ptr->store(c, std::memory_order_relaxed);
    }
  }
  int pending(Handle h) {
    if (h.is_large_) {
      LargeCounts c = Large(h)->load(std::memory_order_relaxed);
      if (PENDING_NOTREADY == NodeStateForStruct(c)) {
        return c.pending;
      } else {
        // The pending count encodes the state once the node has
        // started, so just return 0.
        return 0;
      }
    } else {
      PackedCounts c = Packed(h)->load(std::memory_order_relaxed);
      if (PENDING_NOTREADY == NodeStateForStruct(c)) {
        return c.pending;
      } else {
        // The pending count encodes the state once the node has
        // started, so just return 0.
        return 0;
      }
    }
  }
  int decrement_pending(Handle h, int v) {
    DCHECK_GE(pending(h), v);
    if (h.is_large_) {
      std::atomic<LargeCounts>* c_ptr = Large(h);
      auto c = c_ptr->load(std::memory_order_relaxed);
      c.pending -= v;
      c_ptr->store(c, std::memory_order_relaxed);
      return c.pending;
    } else {
      std::atomic<PackedCounts>* c_ptr = Packed(h);
      auto c = c_ptr->load(std::memory_order_relaxed);
      c.pending -= v;
      c_ptr->store(c, std::memory_order_relaxed);
      return c.pending;
    }
  }
  // Mark a merge node as live.  The low bit of a merge node's pending count
  // is used as its live flag, so becoming live simply clears that bit.
  // REQUIRES: Node corresponding to "h" is a merge node
  void mark_live(Handle h) {
    if (h.is_large_) {
      std::atomic<LargeCounts>* c_ptr = Large(h);
      auto c = c_ptr->load(std::memory_order_relaxed);
      // Only do anything if the node hasn't already started executing.
      if (PENDING_NOTREADY == NodeStateForStruct(c)) {
        c.pending &= ~static_cast<int>(0x1);
        c_ptr->store(c, std::memory_order_relaxed);
      }
    } else {
      std::atomic<PackedCounts>* c_ptr = Packed(h);
      auto c = c_ptr->load(std::memory_order_relaxed);
      // Only do anything if the node hasn't already started executing.
      if (PENDING_NOTREADY == NodeStateForStruct(c)) {
        static_assert(7 == kMaxCountForPackedCounts,
                      "Live flag incorrect for max packed count");
        c.pending &= 0x6;
        c_ptr->store(c, std::memory_order_relaxed);
      }
    }
  }

  int dead_count(Handle h) {
    int r = h.is_large_ ? Large(h)->load(std::memory_order_relaxed).dead_count
                        : Packed(h)->load(std::memory_order_relaxed).dead_count;
    return r;
  }
  void increment_dead_count(Handle h) {
    if (h.is_large_) {
      std::atomic<LargeCounts>* c_ptr = Large(h);
      auto c = c_ptr->load(std::memory_order_relaxed);
      if (PENDING_NOTREADY == NodeStateForStruct(c)) {
        c.dead_count++;
        c_ptr->store(c, std::memory_order_relaxed);
      }
    } else {
      std::atomic<PackedCounts>* c_ptr = Packed(h);
      auto c = c_ptr->load(std::memory_order_relaxed);
      if (PENDING_NOTREADY == NodeStateForStruct(c)) {
        DCHECK_LT(c.dead_count, kMaxCountForPackedCounts);
        c.dead_count++;
        c_ptr->store(c, std::memory_order_relaxed);
      }
    }
  }

  struct AdjustResult {
    bool any_dead;     // True iff the node has any dead inputs.
    bool any_pending;  // True iff the node still has pending inputs.

    AdjustResult(bool any_dead, bool any_pending)
        : any_dead(any_dead), any_pending(any_pending) {}
  };

  // A streamlined routine that does several pieces of bookkeeping at
  // once.  Equivalent to:
  //    if (increment_dead) increment_dead_count(h);
  //    decrement_pending(h, 1);
  //    return {dead_count(h) > 0, pending(h) > 0};
  AdjustResult adjust_for_activation(Handle h, bool increment_dead) {
    DCHECK_GE(pending(h), 1);
    if (h.is_large_) {
      return adjust_for_activation_shared(Large(h), increment_dead);
    } else {
      return adjust_for_activation_shared(Packed(h), increment_dead);
    }
  }

  // The same as the above, but performs the operation atomically, so it is
  // safe to call concurrently from multiple threads.
  AdjustResult adjust_for_activation_atomic(Handle h, bool increment_dead) {
    DCHECK_GE(pending(h), 1);
    if (h.is_large_) {
      return adjust_for_activation_shared_atomic(Large(h), increment_dead);
    } else {
      return adjust_for_activation_shared_atomic(Packed(h), increment_dead);
    }
  }
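
  // Illustrative sketch only (the calling convention below is an assumption
  // about how an executor might use this class, not something defined here):
  // when an input edge to the node identified by "dst_h" is activated, the
  // dead/pending bookkeeping can be folded into a single call:
  //
  //   bool edge_is_dead = ...;  // whether the activating input is dead
  //   auto res = counts.adjust_for_activation(dst_h, edge_is_dead);
  //   if (!res.any_pending) {
  //     // All inputs have arrived; the node can now be scheduled
  //     // (res.any_dead reports whether any of them were dead).
  //   }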

  class Handle {
   public:
    Handle() : byte_offset_(0), is_large_(0) {}

   private:
    friend class PendingCounts;
    int byte_offset_ : 31;  // Byte offset of the rep in PendingCounts object
    bool is_large_ : 1;  // If true, rep is LargeCounts; otherwise PackedCounts
  };

 private:
  template <typename T>
  inline AdjustResult adjust_for_activation_shared(std::atomic<T>* c,
                                                   bool increment_dead) {
    T val = c->load(std::memory_order_relaxed);
    if (increment_dead && PENDING_NOTREADY == NodeStateForStruct(val)) {
      val.dead_count++;
    }
    val.pending--;
    c->store(val, std::memory_order_relaxed);
    return AdjustResult(val.dead_count, val.pending);
  }

  template <typename T>
  inline AdjustResult adjust_for_activation_shared_atomic(std::atomic<T>* c,
                                                          bool increment_dead) {
    T old_val = c->load(std::memory_order_relaxed);
    while (true) {
      T new_val = old_val;
      if (increment_dead && PENDING_NOTREADY == NodeStateForStruct(new_val)) {
        new_val.dead_count++;
      }
      new_val.pending--;
      AdjustResult ret(new_val.dead_count, new_val.pending);
      if (TF_PREDICT_TRUE(c->compare_exchange_weak(old_val, new_val)))
        return ret;
    }
  }

  // We keep track of the pending count and dead input count for each
  // graph node.  The representation used here is designed to be cache
  // efficient for graphs with large numbers of nodes, where most
  // nodes have relatively small maximum pending counts (e.g. for one
  // LSTM model, 99% of 5000+ nodes had in-degrees of 3 or less).  We
  // use one byte to hold both the pending and dead count for a node
  // where these together can fit in one byte, and we fall back to a
  // larger 8-byte representation, stored at an aligned offset in the
  // same byte array, for the rare nodes that need bigger counts.
  // Each frame iteration has its own PendingCounts.

  // We use 3 bits each for dead_count and pending.
  static constexpr int kMaxCountForPackedCounts = 7;

  // Most counts are small, so we pack a pending count and a dead
  // count into 3 bits each, and use 1 bit to indicate that the node
  // has started computing.
  struct PackedCounts {
    uint8 pending : 3;
    uint8 dead_count : 3;
    uint8 has_started : 1;
  };

  // NOTE: alignas(8) is critical to implement efficient atomic<LargeCounts>
  // on MSVC.
  struct alignas(8) LargeCounts {
    uint32 pending;
    uint32 dead_count : 31;
    // NOTE(tlipcon): MSVC won't pack this struct into 8 bytes unless
    // all of the member types are uint32.
    uint32 has_started : 1;
  };
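
  // Illustrative compile-time check of the packing assumption described in
  // the notes above (the fields are expected to fit in exactly 8 bytes,
  // which also keeps std::atomic<LargeCounts> efficient):
  //
  //   static_assert(sizeof(LargeCounts) == 8,
  //                 "LargeCounts should pack into 8 bytes");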

  template <typename T>
  NodeState NodeStateForStruct(const T& c) const {
    if (c.has_started) {
      return (c.pending == 0) ? STARTED : COMPLETED;
    } else {
      return (c.pending == 0) ? PENDING_READY : PENDING_NOTREADY;
    }
  }
  inline std::atomic<LargeCounts>* Large(Handle h) {
    DCHECK(h.is_large_);
    DCHECK_LE(h.byte_offset_ + sizeof(std::atomic<LargeCounts>), num_bytes_);
    DCHECK_EQ(h.byte_offset_ % alignof(std::atomic<LargeCounts>), 0);
    return reinterpret_cast<std::atomic<LargeCounts>*>(bytes_ + h.byte_offset_);
  }
  inline std::atomic<PackedCounts>* Packed(Handle h) {
    DCHECK(!h.is_large_);
    DCHECK_LE(h.byte_offset_ + sizeof(PackedCounts), num_bytes_);
    return reinterpret_cast<std::atomic<PackedCounts>*>(bytes_ +
                                                        h.byte_offset_);
  }

  const int num_bytes_;  // Just for bounds checking in debug mode
  char* bytes_;          // Array of num_bytes_ bytes

  void operator=(const PendingCounts&) = delete;
};

inline PendingCounts::Handle PendingCounts::Layout::CreateHandle(
    size_t max_pending_count, size_t max_dead_count) {
  Handle result;
  if ((max_pending_count > kMaxCountForPackedCounts) ||
      (max_dead_count > kMaxCountForPackedCounts)) {
    constexpr int B = sizeof(std::atomic<LargeCounts>);
    // Round byte offset to proper alignment
    static_assert(
        sizeof(std::atomic<LargeCounts>) >= alignof(std::atomic<LargeCounts>),
        "std::atomic<LargeCounts> must be packed");
    int64 offset = ((static_cast<int64>(next_offset_) + B - 1) / B) * B;
    result.byte_offset_ = offset;
    result.is_large_ = true;
    next_offset_ = result.byte_offset_ + B;
  } else {
    result.byte_offset_ = next_offset_;
    result.is_large_ = false;
    static_assert(sizeof(std::atomic<PackedCounts>) == 1,
                  "std::atomic<PackedCounts> should be a single byte");
    next_offset_ += sizeof(std::atomic<PackedCounts>);
  }
  return result;
}
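
// Worked example of the layout logic above (illustrative only; it assumes
// sizeof(std::atomic<LargeCounts>) == 8, which is typical on 64-bit platforms
// given the 8-byte, alignas(8) struct):
//
//   PendingCounts::Layout layout;
//   auto a = layout.CreateHandle(3, 1);    // packed, byte offset 0
//   auto b = layout.CreateHandle(2, 0);    // packed, byte offset 1
//   auto c = layout.CreateHandle(100, 2);  // large, rounded up to offset 8
//   // next_offset_ is now 16, so PendingCounts(layout) allocates 16 bytes.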

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_PENDING_COUNTS_H_