/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/propagator_state.h"

#include "tensorflow/core/common_runtime/graph_view.h"
#include "tensorflow/core/common_runtime/immutable_executor_state.h"
#include "tensorflow/core/common_runtime/propagator_debug_utils.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/hash.h"
#include "tensorflow/core/profiler/lib/traceme.h"

namespace tensorflow {

PropagatorState::PropagatorState(const ImmutableExecutorState& immutable_state,
                                 int64 step_id, bool vlog)
    : immutable_state_(immutable_state),
      step_id_(step_id),
      vlog_(vlog || VLOG_IS_ON(1)) {
  // We start the entire execution in iteration 0 of the root frame
  // so let us create the root frame and the state for iteration 0.
  // We assume root_frame_->frame_name.empty().
  root_frame_ = new FrameState(immutable_state_, 1);
  root_frame_->frame_id = 0;  // must be 0
  root_frame_->InitializeFrameInfo(immutable_state_.get_root_frame_info());

  // Initialize iteration 0.
  root_frame_->SetIteration(
      0, new PropagatorState::IterationState(0, root_frame_->pending_counts,
                                             root_frame_->total_input_tensors));

  outstanding_frames_.emplace(root_frame_->frame_id, root_frame_);
}

PropagatorState::~PropagatorState() {
  for (auto name_frame : outstanding_frames_) {
    delete name_frame.second;
  }
}

void PropagatorState::ActivateRoots(gtl::ArraySlice<const NodeItem*> roots,
                                    TaggedNodeSeq* ready) {
  mutex_lock l(root_frame_->mu);
  IterationState* root_iter = root_frame_->GetIteration(0);
  for (const NodeItem* item : roots) {
    DCHECK_EQ(item->num_inputs, 0);
    ready->emplace_back(item, root_frame_, root_iter, false);
  }
  root_iter->outstanding_ops = ready->size();
}
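
// Usage sketch (illustrative only, not part of this file): the executor
// seeds propagation with the graph's zero-input nodes, then alternates
// between running ready kernels and propagating their outputs. Assuming the
// interface used in this file plus a root_nodes() accessor on
// ImmutableExecutorState, the driver loop looks roughly like:
//
//   TaggedNodeSeq ready;
//   propagator.ActivateRoots(immutable_state.root_nodes(), &ready);
//   while (!ready.empty()) {
//     TaggedNode tagged_node = /* pop a node off `ready` */;
//     EntryVector outputs;
//     /* ... run tagged_node.node_item->kernel, filling `outputs` ... */
//     propagator.PropagateOutputs(tagged_node, &outputs, &ready);
//   }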

void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
                                       EntryVector* outputs,
                                       TaggedNodeSeq* ready) {
  profiler::TraceMe activity(
      [&]() {
        return strings::StrCat(
            "ExecutorPropagateOutputs#", "id=", step_id_,
            ",kernel_name=", tagged_node.node_item->kernel->name_view(),
            ",num_output_edges=", tagged_node.node_item->num_output_edges,
            ",num_output_control_edges=",
            tagged_node.node_item->num_output_control_edges, "#");
      },
      profiler::GetTFTraceMeLevel(/*is_expensive=*/false));

  const NodeItem* const item = tagged_node.node_item;
  FrameState* const input_frame = tagged_node.input_frame;
  IterationState* const input_iter = tagged_node.input_iter;
  const bool is_dead = tagged_node.is_dead;

  // Propagates outputs along out edges, and puts newly ready nodes
  // into the ready queue.
  DCHECK(ready->empty());
  bool is_frame_done = false;
  FrameState* output_frame = input_frame;
  IterationState* output_iter = input_iter;

  if (!item->is_enter_exit_or_next_iter) {
    // Fast path for node types that don't need special handling.
    // This is the case for most nodes.
    DCHECK_EQ(input_frame, output_frame);
    FrameState* frame = input_frame;
    is_frame_done = frame->ActivateNodesAndAdjustOutstanding(
        item, is_dead, output_iter, outputs, ready);
  } else if (item->is_enter) {
    FindOrCreateChildFrame(input_frame, input_iter, *item, &output_frame);
    {
      mutex_lock l(output_frame->mu);
      output_iter = output_frame->GetIteration(0);
      if (item->is_constant_enter) {
        // Propagate to all active iterations if this is a loop invariant.
        output_frame->AddLoopInv(item, (*outputs)[0], ready);
      } else {
        int activated = output_frame->ActivateNodesLocked(
            item, is_dead, output_iter, outputs, ready);
        output_frame->AdjustOutstandingOpsLocked(output_iter, activated,
                                                 ready);
      }
      output_frame->num_pending_inputs--;
    }
    is_frame_done = input_frame->DecrementOutstandingOps(input_iter, ready);
  } else if (item->is_exit) {
    if (is_dead) {
      mutex_lock l(input_frame->mu);
      // Stop and remember this node if it is a dead exit.
      if (input_iter->iter_num == input_frame->iteration_count) {
        input_frame->dead_exits.push_back(item);
      }
      is_frame_done =
          input_frame->DecrementOutstandingOpsLocked(input_iter, ready);
    } else {
      output_frame = input_frame->parent_frame;
      output_iter = input_frame->parent_iter;
      {
        mutex_lock l(output_frame->mu);
        int activated = output_frame->ActivateNodesLocked(
            item, is_dead, output_iter, outputs, ready);
        output_frame->AdjustOutstandingOpsLocked(output_iter, activated,
                                                 ready);
      }
      is_frame_done = input_frame->DecrementOutstandingOps(input_iter, ready);
    }
  } else {
    DCHECK(item->is_next_iteration);
    mutex_lock l(input_frame->mu);
    if (is_dead) {
      // Stop the deadness propagation.
      output_frame = nullptr;
    } else {
      if (input_iter->iter_num == input_frame->iteration_count &&
          input_frame->num_outstanding_iterations ==
              input_frame->max_parallel_iterations) {
        // Reached the maximum for parallel iterations.
        input_frame->next_iter_roots.push_back({item, (*outputs)[0]});
        output_frame = nullptr;
      } else {
        // If this is a new iteration, start it.
        if (input_iter->iter_num == input_frame->iteration_count) {
          output_iter = input_frame->IncrementIteration(ready);
        } else {
          output_iter = input_frame->GetIteration(input_iter->iter_num + 1);
        }
      }
    }
    if (output_frame != nullptr) {
      // This is a live NextIteration output: propagate it into the next
      // iteration of the same frame.
      DCHECK(input_frame == output_frame);
      int activated = output_frame->ActivateNodesLocked(
          item, is_dead, output_iter, outputs, ready);
      output_frame->AdjustOutstandingOpsLocked(output_iter, activated, ready);
    }
    is_frame_done =
        input_frame->DecrementOutstandingOpsLocked(input_iter, ready);
  }

  // At this point, this node is completely done. We also know if the
  // completion of this node makes its frame completed.
  if (is_frame_done) {
    FrameState* parent_frame = input_frame->parent_frame;
    IterationState* parent_iter = input_frame->parent_iter;
    DeleteFrame(input_frame, ready);
    if (parent_frame != nullptr) {
      // The completion of this frame may cause completions in its parent
      // frame, so clean things up recursively.
      CleanupFramesIterations(parent_frame, parent_iter, ready);
    }
  }
}
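
// Summary of the dispatch above (descriptive only):
//   - Ordinary nodes:  outputs stay in (input_frame, input_iter); this is
//                      the shared-lock fast path taken by most nodes.
//   - Enter:           outputs go to iteration 0 of a child frame; constant
//                      Enters are broadcast to all iterations as loop
//                      invariants.
//   - Exit:            live outputs go to (parent_frame, parent_iter); dead
//                      Exits are parked in dead_exits until the frame is
//                      deleted.
//   - NextIteration:   live outputs go to iteration i+1 of the same frame,
//                      creating it if needed; once max_parallel_iterations
//                      is reached they are deferred in next_iter_roots.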

void PropagatorState::DumpIterationState(const FrameState* frame,
                                         IterationState* iteration) {
  const std::vector<const NodeItem*>* nodes = frame->nodes;
  // Dump any waiting nodes that are holding on to tensors.
  for (const NodeItem* node : *nodes) {
    PendingCounts::Handle pending_id =
        immutable_state_.pending_ids()[node->node_id];
    if (iteration->node_state(pending_id) == PendingCounts::PENDING_NOTREADY ||
        iteration->node_state(pending_id) == PendingCounts::PENDING_READY) {
      DumpPendingNodeState(*node, iteration->input_tensors, false);
    }
  }
  // Then the active nodes.
  for (const NodeItem* node : *nodes) {
    PendingCounts::Handle pending_id =
        immutable_state_.pending_ids()[node->node_id];
    if (iteration->node_state(pending_id) == PendingCounts::STARTED) {
      DumpActiveNodeState(*node, iteration->input_tensors);
    }
  }
  // Show all input tensors in use.
  const int total_input_tensors = frame->total_input_tensors;
  size_t total_bytes = 0;
  for (int i = 0; i < total_input_tensors; ++i) {
    const Entry& input = iteration->input_tensors[i];
    const Tensor* tensor = GetTensorValueForDump(input);
    if (tensor->IsInitialized()) {
      LOG(WARNING) << " Input " << i << ": "
                   << strings::StrCat(
                          "Tensor<type: ", DataTypeString(tensor->dtype()),
                          " shape: ", tensor->shape().DebugString(),
                          ", bytes: ", tensor->TotalBytes(), ">");
      total_bytes += tensor->TotalBytes();
    }
  }
  LOG(WARNING) << " Total bytes " << total_bytes;
}

void PropagatorState::DumpState() {
  mutex_lock l(mu_);
  LOG(WARNING) << "Dumping state";
  for (auto& frame : outstanding_frames_) {
    LOG(WARNING) << frame.first;
    FrameState* frame_state = frame.second;
    frame_state->DumpIterationState(this);
  }
}

void PropagatorState::FindOrCreateChildFrame(FrameState* frame,
                                             IterationState* iter_state,
                                             const NodeItem& node_item,
                                             FrameState** child) {
  // Get the child frame name.
  const ImmutableExecutorState::FrameInfo& frame_info =
      immutable_state_.get_enter_frame_info(node_item);

  const uint64 child_id = Hash64Combine(
      frame->frame_id,
      Hash64Combine(iter_state->iter_num, Hash64(frame_info.name)));

  {
    tf_shared_lock executor_lock(mu_);
    auto it = outstanding_frames_.find(child_id);
    if (it != outstanding_frames_.end()) {
      *child = it->second;
      return;
    }
  }

  // Need to create a new frame instance.
  // Note that this new frame instance is created without any locks.
  if (vlog_) {
    const string child_name = strings::StrCat(
        frame->frame_name, ";", iter_state->iter_num, ";", frame_info.name);
    VLOG(2) << "Create frame: " << child_name << " id: " << child_id;
  }

  FrameState* temp =
      new FrameState(immutable_state_, frame_info.parallel_iterations);
  temp->frame_id = child_id;
  temp->parent_frame = frame;
  temp->parent_iter = iter_state;
  temp->InitializeFrameInfo(frame_info);

  // Initialize iteration 0.
  {
    mutex_lock l(temp->mu);
    temp->SetIteration(0, new IterationState(0, temp->pending_counts,
                                             temp->total_input_tensors));
  }

  {
    mutex_lock executor_lock(mu_);
    auto it = outstanding_frames_.find(child_id);
    if (it != outstanding_frames_.end()) {
      *child = it->second;
    } else {
      mutex_lock frame_lock(frame->mu);
      iter_state->outstanding_frame_count++;
      outstanding_frames_[child_id] = temp;
      *child = temp;
      temp = nullptr;
    }
  }
  delete temp;  // Not used so delete it.
}
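
// Note on the scheme above: frame creation is optimistic. The first lookup
// takes only a shared lock; on a miss, a candidate FrameState is built with
// no locks held, and the exclusive-locked second lookup either installs it
// or deletes it if another thread won the race. Racing threads converge
// because child_id is a deterministic hash of (parent frame id, parent
// iteration number, frame name).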

void PropagatorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {
  // First, propagate dead_exits (if any) to the parent frame.
  FrameState* parent_frame = frame->parent_frame;
  IterationState* parent_iter_state = frame->parent_iter;
  if (parent_frame != nullptr) {
    mutex_lock parent_frame_lock(parent_frame->mu);
    // Propagate all the dead exits to the parent frame.
    mutex_lock this_frame_lock(frame->mu);

    for (const NodeItem* item : frame->dead_exits) {
      auto maybe_add_to_ready = [&](const NodeItem& dst_item, bool dst_ready,
                                    bool dst_dead) {
        if (dst_ready) {
          if (dst_item.is_control_trigger) dst_dead = false;
          ready->emplace_back(&dst_item, parent_frame, parent_iter_state,
                              dst_dead);
          parent_iter_state->outstanding_ops++;
        }
      };

      auto propagate_to_non_merge = [&](PendingCounts::Handle dst_pending_id) {
        parent_iter_state->increment_dead_count(dst_pending_id);
        return parent_iter_state->decrement_pending(dst_pending_id, 1) == 0;
      };

      for (const EdgeInfo& e : item->output_edges()) {
        const NodeItem& dst_item =
            immutable_state_.graph_view().node_ref(e.dst_id);
        const auto dst_pending_id = immutable_state_.pending_ids()[e.dst_id];

        bool dst_dead = true;
        bool dst_ready;
        // We know this is a dead input to dst.
        if (dst_item.is_merge) {
          parent_iter_state->increment_dead_count(dst_pending_id);
          const int dead_cnt = parent_iter_state->dead_count(dst_pending_id);
          dst_dead = (dead_cnt == dst_item.num_inputs);
          dst_ready =
              (parent_iter_state->pending(dst_pending_id) == 1) && dst_dead;
        } else {
          dst_ready = propagate_to_non_merge(dst_pending_id);
        }
        maybe_add_to_ready(dst_item, dst_ready, dst_dead);
      }

      for (const ControlEdgeInfo& e : item->output_control_edges()) {
        const NodeItem& dst_item =
            immutable_state_.graph_view().node_ref(e.dst_id);
        const auto dst_pending_id = immutable_state_.pending_ids()[e.dst_id];

        bool dst_dead;
        bool dst_ready;
        // We know this is a dead input to dst.
        if (dst_item.is_merge) {
          parent_iter_state->decrement_pending(dst_pending_id, 2);
          int count = parent_iter_state->pending(dst_pending_id);
          int dead_cnt = parent_iter_state->dead_count(dst_pending_id);
          dst_dead = (dead_cnt == dst_item.num_inputs);
          dst_ready = (count == 0) || ((count == 1) && dst_dead);
        } else {
          dst_dead = true;
          dst_ready = propagate_to_non_merge(dst_pending_id);
        }
        maybe_add_to_ready(dst_item, dst_ready, dst_dead);
      }
    }
  }

  // Delete the frame.
  if (vlog_) VLOG(2) << "Delete frame " << frame->frame_id;
  {
    mutex_lock executor_lock(mu_);
    outstanding_frames_.erase(frame->frame_id);
  }
  delete frame;
}
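
// Why dead exits are propagated here (descriptive only): the Exit nodes of
// an untaken loop branch never produce values, so their consumers in the
// parent frame would otherwise wait forever. Delivering those edges as dead
// inputs once the child frame is deleted lets deadness flow out of the loop
// and unblocks any downstream Merge nodes in the parent.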

void PropagatorState::CleanupFramesIterations(FrameState* frame,
                                              IterationState* iter_state,
                                              TaggedNodeSeq* ready) {
  bool is_frame_done = false;
  {
    mutex_lock frame_lock(frame->mu);
    iter_state->outstanding_frame_count--;
    is_frame_done = frame->CleanupIterations(iter_state, ready);
  }
  if (is_frame_done) {
    FrameState* parent_frame = frame->parent_frame;
    IterationState* parent_iter = frame->parent_iter;
    DeleteFrame(frame, ready);
    if (parent_frame != nullptr) {
      // The completion of this frame may cause completions in its parent
      // frame, so clean things up recursively.
      CleanupFramesIterations(parent_frame, parent_iter, ready);
    }
  }
}

template <bool atomic>
int PropagatorState::FrameState::ActivateNodesFastPathInternal(
    const NodeItem* item, const bool is_dead, IterationState* iter_state,
    EntryVector* outputs, TaggedNodeSeq* ready) {
  // If we know that none of the item's edge destinations require special
  // handling (i.e. none of the nodes is a merge or control trigger node), we
  // can take a fast path that avoids accessing the destination NodeItem.
  const GraphView& gview = immutable_state.graph_view();
  int new_outstanding = 0;

// Add dst to the ready queue if it's ready.
//
// NOTE(mrry): Use a macro here instead of a lambda, because this method is
// performance-critical and we need to ensure that the code is inlined.
#define MAYBE_ADD_TO_READY(dst_id, adjust_result)         \
  do {                                                    \
    if (!adjust_result.any_pending) {                     \
      const NodeItem* dst_item = &gview.node_ref(dst_id); \
      TaggedNode& t = ready->emplace_back();              \
      t.node_item = dst_item;                             \
      t.input_frame = this;                               \
      t.input_iter = iter_state;                          \
      t.is_dead = adjust_result.any_dead;                 \
      new_outstanding++;                                  \
    }                                                     \
  } while (0);

  Entry* input_tensors = iter_state->input_tensors;
  for (const EdgeInfo& e : item->output_edges()) {
    const int dst_id = e.dst_id;
    const PendingCounts::Handle dst_pending_id =
        immutable_state.pending_ids()[dst_id];
    const int src_slot = e.output_slot;

    const bool increment_dead =
        (is_dead || ((*outputs)[src_slot].state == Entry::State::NO_VALUE));
    const int dst_loc = e.input_slot;
    if (e.is_last) {
      input_tensors[dst_loc] = std::move((*outputs)[src_slot]);
    } else {
      input_tensors[dst_loc] = (*outputs)[src_slot];
    }
    const PendingCounts::AdjustResult adjust_result =
        atomic
            ? iter_state->adjust_for_activation_atomic(dst_pending_id,
                                                       increment_dead)
            : iter_state->adjust_for_activation(dst_pending_id,
                                                increment_dead);
    MAYBE_ADD_TO_READY(dst_id, adjust_result);
  }

  for (const ControlEdgeInfo& e : item->output_control_edges()) {
    const int dst_id = e.dst_id;
    const PendingCounts::Handle dst_pending_id =
        immutable_state.pending_ids()[dst_id];
    const PendingCounts::AdjustResult adjust_result =
        atomic
            ? iter_state->adjust_for_activation_atomic(dst_pending_id, is_dead)
            : iter_state->adjust_for_activation(dst_pending_id, is_dead);
    MAYBE_ADD_TO_READY(dst_id, adjust_result);
  }

  return new_outstanding;
#undef MAYBE_ADD_TO_READY
}
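
// The <atomic> template parameter selects how pending counts are adjusted:
// with atomic=true the adjustment uses adjust_for_activation_atomic so that
// concurrent activations under a shared lock remain safe; with atomic=false
// it uses the plain variant, which is only valid while holding the frame's
// mutex exclusively. The thin ActivateNodesFastPathShared/Locked wrappers
// called elsewhere in this file are declared in propagator_state.h and, by
// this reading, instantiate the template with atomic=true and atomic=false
// respectively.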

int PropagatorState::FrameState::ActivateNodesSlowPath(
    const NodeItem* item, const bool is_dead, IterationState* iter_state,
    EntryVector* outputs, TaggedNodeSeq* ready) {
  // If any of the edge destinations is a merge or a control trigger node,
  // we need to read each destination NodeItem to determine what action
  // to take.
  const GraphView& gview = immutable_state.graph_view();
  int activated = 0;
  auto maybe_add_to_ready = [&](int dst_id, const NodeItem* dst_item,
                                bool dst_ready, bool dst_dead) {
    // Add dst to the ready queue if it's ready.
    if (dst_ready) {
      if (dst_item->is_control_trigger) dst_dead = false;
      ready->emplace_back(dst_item, this, iter_state, dst_dead);
      activated++;
    }
  };

  Entry* input_tensors = iter_state->input_tensors;

  for (const EdgeInfo& e : item->output_edges()) {
    const int dst_id = e.dst_id;
    const NodeItem* dst_item = &gview.node_ref(dst_id);
    const PendingCounts::Handle dst_pending_id =
        immutable_state.pending_ids()[dst_id];
    const int src_slot = e.output_slot;

    bool dst_dead = false;
    bool dst_ready = false;
    bool dst_need_input = true;

    if (dst_item->is_merge) {
      // A merge node is ready if all control inputs have arrived and either
      // a) a live data input becomes available or b) all data inputs are
      // dead. For Merge, pending's LSB is set iff a live data input has
      // arrived.
      if ((*outputs)[src_slot].state != Entry::State::NO_VALUE) {
        // This is a live data input.
        int count = iter_state->pending(dst_pending_id);
        iter_state->mark_live(dst_pending_id);
        // Only the first live edge sets the input and (potentially)
        // triggers execution. The low bit of count is set if and
        // only if no live input has been used yet (mark_live clears
        // it). The node should be started if and only if this is
        // the first live input and there are no pending control
        // edges, i.e. count == 1.
        dst_ready = (count == 1);
        dst_need_input = ((count & 0x1) == 1);
      } else {
        // This is a dead data input. Note that dst_node is dead if node is
        // a dead enter. We need this to handle properly a while loop on
        // the untaken branch of a conditional.
        // TODO(yuanbyu): This is a bit hacky, but a good solution for
        // now.
        iter_state->increment_dead_count(dst_pending_id);
        const int dead_cnt = iter_state->dead_count(dst_pending_id);
        dst_dead = (dead_cnt == dst_item->num_inputs) || item->is_enter;
        dst_ready = (iter_state->pending(dst_pending_id) == 1) && dst_dead;
        dst_need_input = false;
      }
    } else {
      // Handle all other (non-merge) nodes.
      const bool increment_dead =
          (is_dead || ((*outputs)[src_slot].state == Entry::State::NO_VALUE));
      const PendingCounts::AdjustResult adjust_result =
          iter_state->adjust_for_activation(dst_pending_id, increment_dead);
      dst_dead = adjust_result.any_dead;
      dst_ready = !adjust_result.any_pending;
    }

    if (dst_need_input) {
      const int dst_loc = e.input_slot;
      if (e.is_last) {
        input_tensors[dst_loc] = std::move((*outputs)[src_slot]);
      } else {
        input_tensors[dst_loc] = (*outputs)[src_slot];
      }
    }

    maybe_add_to_ready(dst_id, dst_item, dst_ready, dst_dead);
  }

  for (const ControlEdgeInfo& e : item->output_control_edges()) {
    const int dst_id = e.dst_id;
    const NodeItem* dst_item = &gview.node_ref(dst_id);
    const PendingCounts::Handle dst_pending_id =
        immutable_state.pending_ids()[dst_id];

    bool dst_dead;
    bool dst_ready;
    if (dst_item->is_merge) {
      // A merge node is ready if all control inputs have arrived and either
      // a) a live data input becomes available or b) all data inputs are
      // dead. For Merge, pending's LSB is set iff a live data input has
      // arrived.
      iter_state->decrement_pending(dst_pending_id, 2);
      int count = iter_state->pending(dst_pending_id);
      int dead_cnt = iter_state->dead_count(dst_pending_id);
      dst_dead = (dead_cnt == dst_item->num_inputs);
      dst_ready = (count == 0) || ((count == 1) && dst_dead);
    } else {
      // Handle all other (non-merge) nodes.
      const PendingCounts::AdjustResult adjust_result =
          iter_state->adjust_for_activation(dst_pending_id, is_dead);
      dst_dead = adjust_result.any_dead;
      dst_ready = !adjust_result.any_pending;
    }
    maybe_add_to_ready(dst_id, dst_item, dst_ready, dst_dead);
  }

  return activated;
}
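
// Worked example of the Merge encoding above (illustrative; it assumes the
// usual initialization of pending = 1 + 2 * num_control_edges for Merge
// nodes): a Merge with two incoming control edges starts at pending == 5.
// Each arriving control edge subtracts 2 (5 -> 3 -> 1); the first live data
// input then observes count == 1, so it is both stored and fired. A second
// live data input would observe an even count, so it neither overwrites the
// input slot nor triggers a second execution.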

bool PropagatorState::FrameState::ActivateNodesAndAdjustOutstanding(
    const NodeItem* item, const bool is_dead, IterationState* iter_state,
    EntryVector* outputs, TaggedNodeSeq* ready) {
  if (TF_PREDICT_FALSE(item->is_any_consumer_merge_or_control_trigger)) {
    mutex_lock l(mu);
    int activated =
        ActivateNodesSlowPath(item, is_dead, iter_state, outputs, ready);
    return AdjustOutstandingOpsLocked(iter_state, activated - 1, ready);
  }
  {
    tf_shared_lock l(mu);
    int activated =
        ActivateNodesFastPathShared(item, is_dead, iter_state, outputs, ready);
    bool iter_done = AdjustOutstandingOpsFastPath(iter_state, activated - 1);
    if (!iter_done) return false;
  }
  mutex_lock l(mu);
  return CleanupIterations(iter_state, ready);
}
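
// The "activated - 1" above folds two updates into a single adjustment: the
// completing node retires itself (-1) at the same time as the nodes it just
// made ready are added (+activated). A node that activates exactly one
// successor therefore leaves outstanding_ops unchanged.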

int PropagatorState::FrameState::ActivateNodesLocked(const NodeItem* item,
                                                     const bool is_dead,
                                                     IterationState* iter_state,
                                                     EntryVector* outputs,
                                                     TaggedNodeSeq* ready) {
  if (TF_PREDICT_FALSE(item->is_any_consumer_merge_or_control_trigger)) {
    return ActivateNodesSlowPath(item, is_dead, iter_state, outputs, ready);
  } else {
    return ActivateNodesFastPathLocked(item, is_dead, iter_state, outputs,
                                       ready);
  }
}

void PropagatorState::FrameState::ActivateNexts(IterationState* iter_state,
                                                TaggedNodeSeq* ready) {
  int activated = 0;
  // Propagate the deferred NextIteration nodes to the new iteration.
  for (auto& node_entry : next_iter_roots) {
    const NodeItem* item = node_entry.first;
    const Entry& entry = node_entry.second;
    const bool is_dead = entry.state == Entry::State::NO_VALUE;
    EntryVector outputs{entry};
    activated +=
        ActivateNodesLocked(item, is_dead, iter_state, &outputs, ready);
  }
  next_iter_roots.clear();
  AdjustOutstandingOpsLocked(iter_state, activated, ready);
}

void PropagatorState::FrameState::ActivateLoopInvs(IterationState* iter_state,
                                                   TaggedNodeSeq* ready) {
  // Propagate loop invariants to the new iteration.
  int activated = 0;
  for (auto& node_entry : inv_values) {
    const NodeItem* item = node_entry.first;
    const Entry& entry = node_entry.second;
    const bool is_dead = entry.state == Entry::State::NO_VALUE;
    EntryVector outputs{entry};
    activated +=
        ActivateNodesLocked(item, is_dead, iter_state, &outputs, ready);
  }
  AdjustOutstandingOpsLocked(iter_state, activated, ready);
}

void PropagatorState::FrameState::AddLoopInv(const NodeItem* item,
                                             const Entry& entry,
                                             TaggedNodeSeq* ready) {
  // Store this value.
  inv_values.push_back({item, entry});

  // Make this value available to all iterations.
  const bool is_dead = entry.state == Entry::State::NO_VALUE;
  for (int i = 0; i <= iteration_count; ++i) {
    EntryVector outputs{entry};
    IterationState* iter_state = GetIteration(i);
    int activated =
        ActivateNodesLocked(item, is_dead, iter_state, &outputs, ready);
    AdjustOutstandingOpsLocked(iter_state, activated, ready);
  }
}
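
// Loop invariants enter a frame once (through a constant Enter node) but
// must be visible to every iteration, so they are handled in two
// directions: AddLoopInv broadcasts a newly arrived invariant to all
// iterations that already exist, while ActivateLoopInvs replays the stored
// invariants into each iteration created later. Together, every iteration
// sees the full set exactly once.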

bool PropagatorState::FrameState::IsIterationDone(IterationState* iter_state) {
  if (iter_state->outstanding_ops == 0 &&
      iter_state->outstanding_frame_count == 0) {
    if (iter_state->iter_num == 0) {
      // The enclosing frame has no pending input.
      return num_pending_inputs == 0;
    } else {
      // The preceding iteration is deleted (and therefore done).
      return (GetIteration(iter_state->iter_num - 1) == nullptr);
    }
  }
  return false;
}

PropagatorState::IterationState*
PropagatorState::FrameState::IncrementIteration(TaggedNodeSeq* ready) {
  iteration_count++;

  // Initialize the next iteration.
  IterationState* next_iter =
      new IterationState(iteration_count, pending_counts, total_input_tensors);
  SetIteration(iteration_count, next_iter);
  num_outstanding_iterations++;
  dead_exits.clear();

  // Activate the successors of the deferred roots in the new iteration.
  ActivateNexts(next_iter, ready);

  // Activate the loop invariants in the new iteration.
  ActivateLoopInvs(next_iter, ready);

  return next_iter;
}
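
// IncrementIteration is the only place a new iteration is created: either a
// live NextIteration output arrives for an iteration that does not exist
// yet (see PropagateOutputs), or a slot frees up and a deferred root from
// next_iter_roots can start (see CleanupIterations below). In both cases
// the new iteration immediately receives the deferred NextIteration values
// and all stored loop invariants.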

bool PropagatorState::FrameState::CleanupIterations(IterationState* iter_state,
                                                    TaggedNodeSeq* ready) {
  int64 curr_iter = iter_state->iter_num;
  while (curr_iter <= iteration_count && IsIterationDone(iter_state)) {
    delete iter_state;
    SetIteration(curr_iter, nullptr);
    --num_outstanding_iterations;
    ++curr_iter;

    // When an iteration completes, check whether a deferred iteration is
    // pending and start it if so.
    if (!next_iter_roots.empty()) {
      IncrementIteration(ready);
    }

    if (curr_iter <= iteration_count) {
      iter_state = GetIteration(curr_iter);
    }
  }
  return IsFrameDone();
}
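
// Note the ordering invariant: IsIterationDone only returns true once every
// earlier iteration has been deleted, so iterations complete strictly in
// order. That is what allows the loop above to start at the just-finished
// iteration and reclaim a contiguous run of completed iterations in a
// single forward walk.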

void PropagatorState::FrameState::InitializeFrameInfo(
    const ImmutableExecutorState::FrameInfo& finfo) {
  pending_counts = finfo.pending_counts.get();
  total_input_tensors = finfo.total_inputs;
  num_pending_inputs = finfo.input_count;
  nodes = finfo.nodes.get();
}

void PropagatorState::FrameState::SetIteration(int64 iter,
                                               IterationState* state)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
  size_t index = iter % (max_parallel_iterations + 1);
  DCHECK(state == nullptr || iterations[index] == nullptr);
  iterations_raw[index] = state;
  if (index == 0) {
    iterations_first = state;
  }
}
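
// Iteration storage is a fixed ring of max_parallel_iterations + 1 slots
// indexed by iter % (max_parallel_iterations + 1). For example, with
// max_parallel_iterations == 2 there are three slots: iterations 0..2 may
// coexist, and iteration 3 reuses slot 0. The DCHECK above enforces that a
// slot is only reused after the iteration that previously occupied it has
// been deleted.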

// Decrement the outstanding op count and clean up the iterations in the
// frame. Return true iff the execution of the frame is done.
bool PropagatorState::FrameState::DecrementOutstandingOps(
    IterationState* iter_state, TaggedNodeSeq* ready) {
  return AdjustOutstandingOps(iter_state, -1, ready);
}

bool PropagatorState::FrameState::AdjustOutstandingOps(
    IterationState* iter_state, int delta, TaggedNodeSeq* ready) {
  // Given the following profile of values of 'delta' for the wide_deep model
  // from the TF model garden:
  //
  //   Count  Value
  //  ----------------
  //  757938  delta=0x0
  //  541713  delta=0xffffffff
  //  138115  delta=0x1
  //   58770  delta=0x2
  //    5394  delta=0x3
  //    4669  delta=0x4
  //    2037  delta=0xa
  //    1646  delta=0x7
  //    1632  delta=0x6
  //    1613  delta=0x6c
  //    1224  delta=0x5
  //     409  delta=0x53
  //      17  delta=0x86
  //
  // ... it's worth no-opping out when delta == 0 to avoid the atomic
  // instruction.
  if (delta == 0) {
    return false;
  }
  {
    tf_shared_lock sl(mu);
    if (TF_PREDICT_TRUE(!AdjustOutstandingOpsFastPath(iter_state, delta))) {
      return false;
    }
  }
  mutex_lock l(mu);
  DCHECK(IsIterationDone(iter_state));
  return CleanupIterations(iter_state, ready);
}

bool PropagatorState::FrameState::AdjustOutstandingOpsFastPath(
    IterationState* iter_state, int delta) {
  auto old_val = iter_state->outstanding_ops.fetch_add(delta);
  return (old_val + delta == 0) && IsIterationDone(iter_state);
}
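
// The two-phase pattern above is deliberate: the common case (the count is
// still nonzero, or the iteration is not yet done) is resolved under a
// shared lock with a single atomic fetch_add, and only the rare
// "iteration finished" case pays for the exclusive lock that
// CleanupIterations requires. The DCHECK holds because once outstanding_ops
// reaches zero, no op remains in the iteration that could make it un-done
// before the exclusive lock is acquired.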

// Decrement the outstanding op count and clean up the iterations in the
// frame. Return true iff the execution of the frame is done.
bool PropagatorState::FrameState::DecrementOutstandingOpsLocked(
    IterationState* iter_state, TaggedNodeSeq* ready)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
  return AdjustOutstandingOpsLocked(iter_state, -1, ready);
}

bool PropagatorState::FrameState::AdjustOutstandingOpsLocked(
    IterationState* iter_state, int delta, TaggedNodeSeq* ready) {
  // We hold the lock, so we don't need to use an atomic modification.
  auto cur_val = iter_state->outstanding_ops.load(std::memory_order_relaxed);
  DCHECK(delta >= 0 || cur_val >= -delta)
      << "cannot adjust outstanding_ops by " << delta
      << " when current value is " << cur_val;
  auto new_val = cur_val + delta;
  iter_state->outstanding_ops.store(new_val, std::memory_order_relaxed);
  if (new_val != 0) {
    return false;
  }
  return CleanupIterations(iter_state, ready);
}

// Returns true if the computation in the frame is completed.
bool PropagatorState::FrameState::IsFrameDone()
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
  return (num_pending_inputs == 0 && num_outstanding_iterations == 0);
}

}  // namespace tensorflow