/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/ring_reducer.h"

#include <stdlib.h>

#include <atomic>
#include <functional>
#include <utility>

#include "tensorflow/core/common_runtime/collective_rma_local.h"
#include "tensorflow/core/common_runtime/collective_util.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"

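// RingReducer runs a ring all-reduce over the devices of a collective group:
// the output tensor is treated as group_size_ * num_subdivs_ chunks, and each
// device repeatedly receives a chunk from its ring predecessor, combines it
// with its local data via merge_op, and forwards the result to its successor,
// until every device holds the fully reduced tensor (see RunAsyncParts below).
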
namespace tensorflow {

RingReducer::~RingReducer() { group_size_tensor_ready_.WaitForNotification(); }

Status RingReducer::InitializeCollectiveParams(CollectiveParams* col_params) {
  // TODO(b/113171733): change CHECKs to return errors.
  CHECK_EQ(col_params->instance.type, REDUCTION_COLLECTIVE);
  CHECK_EQ(col_params->instance.impl_details.collective_name, "RingReduce");
  return RingAlg::InitializeCollectiveParams(col_params);
}

void RingReducer::Run(StatusCallback done) {
  CHECK(col_ctx_);
  CHECK(col_params_);
  // Since `RingReducer` doesn't require non-overlapping collectives, unblock
  // any collective that is blocked on this instance.
  col_ctx_->col_exec->UnblockDependencies(*col_params_);

  done_ = std::move(done);
  group_size_ = col_params_->group.group_size;
  num_subdivs_ = static_cast<int>(
      col_params_->instance.impl_details.subdiv_permutations.size());
  CHECK_GT(num_subdivs_, 0);

  if (VLOG_IS_ON(1)) {
    string buf;
    for (int r = 0; r < col_params_->group.device_names.size(); ++r) {
      strings::StrAppend(&buf, "dev ", r, " : ",
                         col_params_->group.device_names[r], "\n");
    }
    for (int sd = 0;
         sd < col_params_->instance.impl_details.subdiv_permutations.size();
         ++sd) {
      strings::StrAppend(&buf, "\nsubdiv ", sd, " perm: ");
      for (auto x :
           col_params_->instance.impl_details.subdiv_permutations[sd]) {
        strings::StrAppend(&buf, x, ", ");
      }
    }
    VLOG(1) << "RingReducer::Run for device " << col_ctx_->device_name
            << " default_rank " << col_params_->default_rank << "\n"
            << buf;
  }

  // Start by copying input to output if they're not already the same, i.e. if
  // we're not computing in-place on the input tensor.
  if ((col_ctx_->input != col_ctx_->output) &&
      (DMAHelper::base(col_ctx_->input) != DMAHelper::base(col_ctx_->output))) {
    // We are running in a blockable thread and the callback can't block so
    // just wait here on the copy.
    Notification note;
    Status status;
    profiler::TraceMe activity("MemCpyAsync", profiler::TraceMeLevel::kInfo);
    CollectiveRemoteAccessLocal::MemCpyAsync(
        col_ctx_->op_ctx->op_device_context(),
        col_ctx_->op_ctx->op_device_context(), col_ctx_->device,
        col_ctx_->device, col_ctx_->op_ctx->input_alloc_attr(0),
        col_ctx_->op_ctx->output_alloc_attr(0), col_ctx_->input,
        col_ctx_->output, 0 /*dev_to_dev_stream_index*/,
        [&note, &status](const Status& s) {
          status.Update(s);
          note.Notify();
        });
    note.WaitForNotification();
    if (!status.ok()) {
      done_(status);
      return;
    }
  }
  ContinueAfterInputCopy();
}

// Note that this function is blocking and must not run in any thread
// which cannot be blocked.
void RingReducer::ContinueAfterInputCopy() {
  AllocatorAttributes attr = col_ctx_->op_ctx->output_alloc_attr(0);
  ca_.reset(MakeCollectiveAdapter(col_ctx_->output, group_size_ * num_subdivs_,
                                  col_ctx_->device->GetAllocator(attr)));

  if (col_params_->final_op) {
    // Create an on-device scalar value from group_size_ that may be needed
    // later.
    // TODO(tucker): Cache and reuse across invocations? Or maybe the scalar
    // can be provided to the kernel in host memory?
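    // The scalar is consumed in RF_FINALIZE below, where final_op is applied
    // to each fully reduced chunk (e.g. dividing by the group size when the
    // collective computes a mean).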
    Tensor group_size_val = ca_->Scalar(group_size_);
    if (col_params_->group.device_type != "CPU") {
      uint64 safe_alloc_frontier = col_ctx_->device->SafeAllocFrontier(0);
      AllocationAttributes aa;
      std::function<uint64()> freed_by_func = [this, &safe_alloc_frontier]() {
        safe_alloc_frontier =
            col_ctx_->device->SafeAllocFrontier(safe_alloc_frontier);
        return safe_alloc_frontier;
      };
      if (safe_alloc_frontier > 0) {
        aa.freed_by_func = &freed_by_func;
      }
      group_size_tensor_ = ca_->Scalar(
          col_ctx_->device->GetAllocator(col_ctx_->op_ctx->input_alloc_attr(0)),
          aa);
      DeviceContext* op_dev_ctx = col_ctx_->op_ctx->op_device_context();
      op_dev_ctx->CopyCPUTensorToDevice(
          &group_size_val, col_ctx_->device, &group_size_tensor_,
          [this](const Status& s) {
            if (!s.ok()) {
              StartAbort(s);
            }
            group_size_tensor_ready_.Notify();
          },
          (safe_alloc_frontier == 0));
    } else {
      group_size_tensor_ = group_size_val;
      group_size_tensor_ready_.Notify();
    }
  } else {
    // Value won't be used, so no need to initialize.
    group_size_tensor_ready_.Notify();
  }
  Finish(RunAsyncParts());
}

void RingReducer::InitRingField(RingField* rf, int chunk_idx, int subdiv_idx,
                                int field_idx) {
  RingAlg::InitRingField(rf, chunk_idx, subdiv_idx, field_idx);
  if (rf->do_recv) {
    rf->tmp_chunk = ca_->TempChunk(rf->sc_idx);
  }
}

// At the beginning of the algorithm initialize a RingField struct for
// every independent field of the tensor.
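// Each field then advances through the RF_* state machine in RunAsyncParts:
// receive a chunk from the ring predecessor, combine it with the local chunk
// via merge_op, optionally apply final_op, then forward it to the successor.
// Fields make two passes around the ring (a reduce pass, then a gather pass
// entered via AdvanceToSecondPass) so every device ends with fully reduced
// chunks.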
bool RingReducer::RunAsyncParts() {
  // This function orchestrates RingReduce actions on behalf of a
  // single device. It is entered by a blockable thread that
  // loops within it until all actions assigned to that device
  // complete. Hence function local variables are accessible only by that
  // one thread and do not require an explicit mutex.
  rfv_.clear();
  rfv_.resize(group_size_ * num_subdivs_);
  PCQueue ready_queue;
  for (int chunk_idx = 0; chunk_idx < group_size_; ++chunk_idx) {
    for (int subdiv_idx = 0; subdiv_idx < num_subdivs_; ++subdiv_idx) {
      int rf_index = (chunk_idx * num_subdivs_) + subdiv_idx;
      InitRingField(&rfv_[rf_index], chunk_idx, subdiv_idx, rf_index);
      ready_queue.Enqueue(&rfv_[rf_index]);
    }
  }
  const DeviceBase::GpuDeviceInfo* gpu_info =
      col_ctx_->device->tensorflow_gpu_device_info();
  if (gpu_info) {
    // Wait for all currently queued events on the CPU compute stream to
    // complete before proceeding.  The previous InitRingField calls allocated
    // temp memory buffers that are not guaranteed to be valid (e.g. for RDMA
    // write) unless we do.
    profiler::TraceMe activity("WaitForQueuedEvents",
                               profiler::TraceMeLevel::kInfo);
    Notification note;
    Status s = gpu_info->default_context->ThenExecute(
        col_ctx_->device, gpu_info->stream, [&note]() { note.Notify(); });
    if (s.ok()) {
      note.WaitForNotification();
    } else {
      mutex_lock l(status_mu_);
      status_ =
          errors::Internal("Failed to dispatch ThenExecute in RingReducer");
      return false;
    }
  }

  int field_done_count = 0;
  int send_pending_count = 0;
  int recv_pending_count = 0;
  std::atomic<bool> aborted(false);

  {
    profiler::TraceMe activity("Loop", profiler::TraceMeLevel::kInfo);
    // Loop until all RingFields have advanced to completion.
    while (field_done_count < rfv_.size()) {
      VLOG(4) << FieldState();
      // Wait for a RingField to appear in the ready_queue.
      RingField* rf = ready_queue.Dequeue();
      // Advance the RingField to its next action and execute, repeating
      // until either an async action has been started or the RingField
      // is done.
      bool dispatched = false;  // true if async action was initiated
      do {
        if (aborted) {
          // Requeue this RingField to be counted off below.
          ready_queue.Enqueue(rf);
          break;
        }
        switch (rf->action) {
          case RF_INIT:
            if (rf->do_recv) {
              rf->action = RF_RECV;
              auto requeue = [this, rf, &ready_queue, &aborted](Status s) {
                if (!s.ok()) {
                  aborted = true;
                  StartAbort(s);
                }
                ready_queue.Enqueue(rf);
              };
              DispatchRecv(rf, requeue);
              dispatched = true;
              ++recv_pending_count;
            } else {
              rf->action = RF_SEND_READY;
            }
            break;
          case RF_RECV:
            CHECK_GT(recv_pending_count, 0);
            --recv_pending_count;
            if (!rf->second_pass) {
              rf->action = RF_REDUCE;
              Status s = collective_util::ComputeBinOp(
                  col_ctx_->op_ctx, col_ctx_->op_params, col_ctx_->device,
                  col_params_->merge_op, &rf->chunk, &rf->tmp_chunk);
              if (!s.ok()) {
                aborted = true;
                StartAbort(s);
              }
            } else {
              rf->action = RF_SEND_READY;
            }
            break;
          case RF_REDUCE:
            if (!rf->second_pass && col_params_->final_op && rf->is_final) {
              rf->action = RF_FINALIZE;
              group_size_tensor_ready_.WaitForNotification();
              Status s = collective_util::ComputeBinOp(
                  col_ctx_->op_ctx, col_ctx_->op_params, col_ctx_->device,
                  col_params_->final_op, &rf->chunk, &group_size_tensor_);
              if (!s.ok()) {
                aborted = true;
                StartAbort(s);
              }
            } else {
              rf->action = RF_SEND_READY;
            }
            break;
          case RF_FINALIZE:
            rf->action = RF_DONE;
            break;
          case RF_SEND_READY:
            if (rf->do_send) {
              rf->action = RF_SEND;
              auto send_complete = [this, rf, &ready_queue,
                                    &aborted](Status s) {
                if (!s.ok()) {
                  aborted = true;
                  StartAbort(s);
                }
                ready_queue.Enqueue(rf);
              };
              DispatchSend(rf, send_complete);
              dispatched = true;
              ++send_pending_count;
            } else {
              rf->action = RF_DONE;
            }
            break;
          case RF_SEND:
            CHECK_GT(send_pending_count, 0);
            --send_pending_count;
            rf->action = RF_DONE;
            break;
          case RF_DONE:
            break;
        }
        if (rf->action == RF_DONE) {
          if (rf->second_pass) {
            ++field_done_count;
            break;  // from do while(!dispatched)
          } else {
            AdvanceToSecondPass(rf);
          }
        }
      } while (!dispatched);
      if (aborted) break;
    }  // while (field_done_count < number of fields)

    if (aborted) {
      // All of the pending data actions should be aborted; field the
      // callbacks and clear the queue before quitting.
      while ((send_pending_count > 0) || (recv_pending_count > 0)) {
        RingField* rf = ready_queue.Dequeue();
        switch (rf->action) {
          case RF_RECV:
            --recv_pending_count;
            break;
          case RF_SEND:
            --send_pending_count;
            break;
          default: {
          }  // Ignore any other actions
        }
      }
    }
  }

  CHECK_EQ(send_pending_count, 0);
  CHECK_EQ(recv_pending_count, 0);

  VLOG(2) << this << " device=" << col_ctx_->device_name << " finish;"
          << " final value " << TensorDebugString(ca_->Value());
  return !aborted;
}

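// Registers this implementation under the collective name "RingReduce", the
// name checked in InitializeCollectiveParams above.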
namespace {
REGISTER_COLLECTIVE(RingReduce, RingReducer);
}  // namespace

}  // namespace tensorflow