/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/ring_alg.h"

#include <stdlib.h>

#include <atomic>
#include <functional>
#include <utility>

#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/collective_rma_local.h"
#include "tensorflow/core/common_runtime/collective_util.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/types.h"

// Set true for greater intelligibility of debug mode log messages.
#define READABLE_KEYS false
// A ring algorithm exchanges chunks of tensor between devices.  The chunk size
// depends on the number of subdivisions specified in the algorithm.  If the
// user does not specify the number of subdivisions we may infer the number
// dynamically so that the resulting chunk size does not exceed
// kMaxChunkSizeBytes, empirically set at 4 MiB.
constexpr size_t kMaxChunkSizeBytes = (4 * 1024 * 1024);
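// For example, a 32 MiB tensor reduced across a group of 4 devices would use
// 8 MiB chunks with a single subdivision; that exceeds kMaxChunkSizeBytes, so
// a second subdivision is generated, bringing the chunk size down to 4 MiB
// (see GenerateSubdivsInCollectiveParams below).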
// kMaxSubdivsPerDevice is used to give an upper bound on the number of
// subdivisions dynamically generated.  A reasonable value would be a small
// multiple of the number of NICs adjacent to each device.
constexpr int kMaxSubdivsPerDevice = 2;

namespace tensorflow {
namespace {
// Each CollectiveOp implementation is free to define its own
// BufRendezvous key format.  This function produces the key used by
// RingAlg instances.  Note that the exec_key will differentiate between
// different instances; consequently we don't need to further differentiate
// between subclasses of RingAlg.
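// For example, with READABLE_KEYS=true the key has the form
// "<name>(<exec_key>):pass(0):section(3):srcrank(2)"; with the default
// setting it is the denser "<exec_key>:0:3:2".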
string RingAlgBufKey(const string& name, const string& exec_key, int pass,
                     int section, int source_rank) {
  if (READABLE_KEYS) {
    return strings::StrCat(name, "(", exec_key, "):pass(", pass, "):section(",
                           section, "):srcrank(", source_rank, ")");
  } else {
    // TODO(b/78352018): Try out some kind of denser encoding, e.g. 128 bit
    // hash.
    return strings::StrCat(exec_key, ":", pass, ":", section, ":", source_rank);
  }
}

}  // namespace

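// PCQueue is a simple producer/consumer queue of RingField pointers: Enqueue
// appends a field and wakes one blocked waiter; Dequeue blocks until a field
// is available and then removes it from the front of the queue.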
void RingAlg::PCQueue::Enqueue(RingField* rf) {
  mutex_lock l(pcq_mu_);
  deque_.push_back(rf);
  if (waiter_count_ > 0) {
    cv_.notify_one();
  }
}

RingAlg::RingField* RingAlg::PCQueue::Dequeue() {
  mutex_lock l(pcq_mu_);
  if (deque_.empty()) {
    ++waiter_count_;
    while (deque_.empty()) {
      cv_.wait(l);
    }
    --waiter_count_;
  }
  RingField* rf = deque_.front();
  deque_.pop_front();
  return rf;
}

RingAlg::RingAlg(CollectiveType type, const string& name)
    : type_(type),
      name_(name),
      col_ctx_(nullptr),
      col_params_(nullptr),
      done_(nullptr),
      group_size_(-1),
      num_subdivs_(-1) {}

namespace {
Status GenerateSubdivsInCollectiveParams(CollectiveParams* col_params) {
  if (col_params->instance.shape.num_elements() == 0) {
    return errors::Internal("shape in CollectiveParams should be non-empty");
  }
  const int kAvgDevPerTask =
      col_params->group.group_size / col_params->group.num_tasks;
  const int kMaxNumSubdivs = kMaxSubdivsPerDevice * kAvgDevPerTask;
  if (kMaxNumSubdivs <= 0) {
    return errors::Internal("Unexpected kMaxNumSubdivs ", kMaxNumSubdivs,
                            " in ",
                            col_params->instance.impl_details.collective_name);
  }
  // NOTE(ayushd): If no subdiv_offsets have been specified, dynamically add
  // as many offsets as needed so that the size of tensor chunks <=
  // kMaxChunkSizeBytes.  Empirically, chunks that are too small or too large
  // lead to worse performance.
  int num_subdivs = 0;
  const size_t tensor_size = col_params->instance.shape.num_elements() *
                             DataTypeSize(col_params->instance.data_type);
  size_t chunk_size;
  do {
    ++num_subdivs;
    int num_chunks = col_params->group.group_size * num_subdivs;
    chunk_size = tensor_size / num_chunks;
    VLOG(2) << "num_subdivs " << num_subdivs << " num_chunks " << num_chunks
            << " chunk_size " << chunk_size;
  } while (chunk_size > kMaxChunkSizeBytes && num_subdivs < kMaxNumSubdivs);
  if (num_subdivs <= 0) {
    return errors::Internal("Unexpected num_subdivs ", num_subdivs, " in ",
                            col_params->instance.impl_details.collective_name);
  }

  int subdiv_stride = kAvgDevPerTask / num_subdivs;
  if (subdiv_stride == 0) subdiv_stride = 1;
  col_params->instance.impl_details.subdiv_offsets.reserve(num_subdivs);
  for (int sdi = 0; sdi < num_subdivs; ++sdi) {
    int subdiv_offset = subdiv_stride * sdi;
    if (sdi % 2 == 1) subdiv_offset *= -1;
    col_params->instance.impl_details.subdiv_offsets.push_back(subdiv_offset);
  }
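  // For example, with num_subdivs = 3 and subdiv_stride = 2 the generated
  // offsets would be 0, -2, 4: odd-indexed subdivisions get a negative
  // offset, which reverses the local device order (see
  // InitializeCollectiveParams below).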

  if (VLOG_IS_ON(2)) {
    string subdiv_buf;
    for (const int subdiv_offset :
         col_params->instance.impl_details.subdiv_offsets) {
      strings::StrAppend(&subdiv_buf, " ", subdiv_offset);
    }
    VLOG(2) << "Dynamically generated " << num_subdivs
            << " subdiv_offsets:" << subdiv_buf << " tensor_size "
            << tensor_size << " chunk_size " << chunk_size;
  }

  return Status::OK();
}
}  // namespace

Status RingAlg::InitializeCollectiveParams(CollectiveParams* col_params) {
  const string& device_name =
      col_params->group.device_names[col_params->default_rank];
  // Each subdiv permutation is a ring formed by rotating each
  // single-task subsequence of devices by an offset.  This makes most
  // sense when each task has the same number of devices but we can't
  // depend on that being the case so we'll compute something that
  // works in any case.
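  // For example, with two tasks of two devices each and a subdiv offset of
  // 1, each task's device subsequence is rotated by one position, giving
  // the subdiv permutation {1, 0, 3, 2}.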

  // Start by counting the devices in each task.
  // Precondition: device_names must be sorted so that all devices in
  // the same task are adjacent.
  VLOG(2) << "Sorted task names: "
          << absl::StrJoin(col_params->group.task_names, ", ");
  std::vector<int> dev_per_task;
  const string* prior_task_name = &col_params->group.task_names[0];
  int dev_count = 1;
  for (int di = 1; di < col_params->group.group_size; ++di) {
    if (col_params->group.task_names[di] != *prior_task_name) {
      dev_per_task.push_back(dev_count);
      dev_count = 1;
      prior_task_name = &col_params->group.task_names[di];
    } else {
      ++dev_count;
    }
  }
  dev_per_task.push_back(dev_count);
  DCHECK_EQ(col_params->group.num_tasks, dev_per_task.size());

  if (col_params->instance.impl_details.subdiv_offsets.empty()) {
    TF_RETURN_IF_ERROR(GenerateSubdivsInCollectiveParams(col_params));
  }

  // Generate a ring permutation for requested offset.
  VLOG(2) << "Setting up perms for col_params " << col_params
          << " subdiv_permutations "
          << &col_params->instance.impl_details.subdiv_permutations;
  col_params->instance.impl_details.subdiv_permutations.resize(
      col_params->instance.impl_details.subdiv_offsets.size());
  col_params->subdiv_rank.resize(
      col_params->instance.impl_details.subdiv_offsets.size(), -1);
  for (int sdi = 0;
       sdi < col_params->instance.impl_details.subdiv_offsets.size(); ++sdi) {
    std::vector<int>& perm =
        col_params->instance.impl_details.subdiv_permutations[sdi];
    DCHECK_EQ(perm.size(), 0);
    int offset = col_params->instance.impl_details.subdiv_offsets[sdi];
    // A negative subdivision offset is interpreted as follows:
    //  1. Reverse the local device ordering.
    //  2. Begin the subdivision at abs(offset) in the reversed ordering.
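    // For example, with 3 devices per task an offset of -1 reverses the
    // local order to {2, 1, 0} and then starts at index 1, producing the
    // local sequence {1, 0, 2}.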
    bool reverse = false;
    if (offset < 0) {
      offset = abs(offset);
      reverse = true;
    }
    int prior_dev_count = 0;  // sum over prior worker device counts
    for (int ti = 0; ti < col_params->group.num_tasks; ++ti) {
      for (int di = 0; di < dev_per_task[ti]; ++di) {
        int di_offset = (di + offset) % dev_per_task[ti];
        int offset_di =
            reverse ? (dev_per_task[ti] - (di_offset + 1)) : di_offset;
        // Device index in global subdivision permutation.
        int permuted_di = prior_dev_count + offset_di;
        int rank = static_cast<int>(perm.size());
        perm.push_back(permuted_di);
        if (col_params->group.device_names[permuted_di] == device_name) {
          DCHECK_EQ(permuted_di, col_params->default_rank);
          col_params->subdiv_rank[sdi] = rank;
        }
      }
      prior_dev_count += dev_per_task[ti];
    }
    DCHECK_EQ(col_params->group.group_size, perm.size());
  }

  VLOG(2) << collective_util::SubdivPermDebugString(*col_params);
  return Status::OK();
}

Status RingAlg::InitializeCollectiveContext(
    std::shared_ptr<CollectiveContext> col_ctx) {
  DCHECK(col_ctx->dev_mgr);
  col_ctx_ = col_ctx;
  col_params_ = col_ctx->col_params;
  return collective_util::InitializeDeviceAndLocality(
      col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device,
      &col_ctx->device_locality);
}

string RingAlg::TensorDebugString(const Tensor& tensor) {
  const DeviceBase::GpuDeviceInfo* gpu_device_info =
      col_ctx_->op_ctx->device()->tensorflow_gpu_device_info();
  if (gpu_device_info) {
    Tensor cpu_tensor(tensor.dtype(), tensor.shape());
    Status st = gpu_device_info->default_context->CopyDeviceTensorToCPUSync(
        &tensor, "" /*tensor_name*/, col_ctx_->device, &cpu_tensor);
    DCHECK(st.ok());
    return cpu_tensor.SummarizeValue(64);
  } else {
    return tensor.SummarizeValue(64);
  }
}

void RingAlg::StartAbort(const Status& s) {
  // In abort mode we stop issuing additional ProvideBuf
  // and ConsumeBuf calls, but we need to wait for all of the
  // outstanding callbacks to be invoked before quitting.
  bool abort_started = false;
  {
    mutex_lock l(status_mu_);
    if (status_.ok()) {
      LOG(ERROR) << "Aborting Ring" << name_ << " with " << s;
      abort_started = true;
      status_.Update(s);
    }
  }
  // If this is the initial entry to abort mode and it's not a cancellation,
  // then invoke StartAbort on the CollectiveExecutor that invoked us.  That
  // should start cancellation on all of the outstanding CollectiveRemoteAccess
  // actions.  If it is a cancellation, all pending send/recv ops should be
  // cancelled as well, so there's no need to abort.
  if (abort_started) {
    if (col_ctx_->op_ctx->cancellation_manager() == nullptr ||
        (!col_ctx_->op_ctx->cancellation_manager()->IsCancelled() &&
         !col_ctx_->op_ctx->cancellation_manager()->IsCancelling())) {
      col_ctx_->col_exec->StartAbort(s);
    }
  }
}

void RingAlg::Finish(bool ok) {
  if (ok) {
    // Recover the output from the adaptor.
    ca_->ConsumeFinalValue(col_ctx_->output);
  }
  Status s;
  {
    mutex_lock l(status_mu_);
    s = status_;
  }
  rfv_.clear();  // Give up Refs on output tensor.
  done_(s);
}

// At the beginning of the algorithm initialize a RingField struct for
// every independent field of the tensor.
void RingAlg::InitRingField(RingField* rf, int chunk_idx, int subdiv_idx,
                            int field_idx) {
  // Note on field indexing: There are group_size_ devices in the
  // instance, implying the same number of chunks per tensor, where a
  // chunk is the unit of data transferred in a time step.  However, if
  // a device can simultaneously send data by 2 or more independent
  // channels we can speed up the transfer by subdividing chunks and
  // processing multiple subdivisions at once.  So the actual number
  // of RingFields is group_size_ * num_subdivs_.
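  // For example, with group_size_ = 4 and num_subdivs_ = 2 there are 8
  // RingFields, and the field for chunk_idx 3, subdiv_idx 1 has
  // field_idx = 3 * 2 + 1 = 7.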
  DCHECK_EQ(field_idx, (chunk_idx * num_subdivs_) + subdiv_idx);
  rf->chunk_idx = chunk_idx;
  rf->subdiv_idx = subdiv_idx;
  rf->sc_idx = field_idx;
  rf->rank = col_params_->subdiv_rank[subdiv_idx];
  rf->second_pass = false;
  rf->action = RF_INIT;
  // Recv from the device with preceding rank within the subdivision.
  int recv_from_rank = (rf->rank + (group_size_ - 1)) % group_size_;
  int send_to_rank = (rf->rank + 1) % group_size_;
  rf->recv_dev_idx = col_params_->instance.impl_details
                         .subdiv_permutations[subdiv_idx][recv_from_rank];
  int send_dev_idx = col_params_->instance.impl_details
                         .subdiv_permutations[subdiv_idx][send_to_rank];
  rf->recv_is_remote = !col_params_->task.is_local[rf->recv_dev_idx];
  rf->send_is_remote = !col_params_->task.is_local[send_dev_idx];
  if (ca_->ChunkBytes(rf->sc_idx) > 0) {
    // In pass 0 we skip Recv when rank = chunk_idx
    rf->do_recv = (rf->chunk_idx != rf->rank);
    // In pass 0 we skip Send when rank = chunk_idx-1
    rf->do_send =
        (rf->rank != ((rf->chunk_idx + (group_size_ - 1)) % group_size_));
  }
  rf->is_final =
      (rf->rank == ((rf->chunk_idx + (group_size_ - 1)) % group_size_));
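  // For example, with group_size_ = 4 and a non-empty chunk, the field for
  // chunk_idx 3 at rank 2 has do_recv = true, do_send = false
  // (2 == (3 + 3) % 4), and is_final = true in this first pass.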
  if (rf->do_send || rf->do_recv) {
    rf->chunk = ca_->ChunkAlias(rf->sc_idx);
  }
  VLOG(2) << this << " InitRingField " << rf->DebugString() << " chunk "
          << ca_->TBounds(rf->chunk);
}

// When a RingField transitions from the first to the second pass, recompute
// the do_send and do_recv values.
void RingAlg::AdvanceToSecondPass(RingField* rf) {
  VLOG(3) << "IncrRingField old value " << rf->DebugString();
  DCHECK(!rf->second_pass);
  rf->second_pass = true;
  rf->action = RF_INIT;
  if (ca_->ChunkBytes(rf->sc_idx) > 0) {
    // In pass 1 the send/no-send boundary moves down 1 place.
    rf->do_recv =
        (rf->rank != ((rf->chunk_idx + (group_size_ - 1)) % group_size_));
    rf->do_send =
        (rf->rank != ((rf->chunk_idx + (group_size_ - 2)) % group_size_));
  }
  rf->is_final =
      (rf->rank == ((rf->chunk_idx + (group_size_ - 2)) % group_size_));
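  // Continuing the group_size_ = 4 example: in pass 1 the field for
  // chunk_idx 3 is received by every rank except 2, sent by every rank
  // except 1, and rank 1 holds the final value.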
  VLOG(3) << "IncrRingField new value " << rf->DebugString();
}

string RingAlg::RingField::DebugString() const {
  string rv = strings::StrCat("RingField rank=", rank, " chunk_idx=", chunk_idx,
                              " subdiv=", subdiv_idx, " sc_idx=", sc_idx,
                              " action=", action);
  strings::StrAppend(&rv, " pass=", second_pass);
  strings::StrAppend(&rv, " do_send=", do_send, " do_recv=", do_recv,
                     " is_final=", is_final, " recv_is_remote=", recv_is_remote,
                     " recv_dev_idx=", recv_dev_idx, " sc_idx=", sc_idx);
  return rv;
}

void RingAlg::DispatchSend(RingField* rf, const StatusCallback& done) {
  DCHECK(rf->do_send);
  string send_buf_key = RingAlgBufKey(name_, col_ctx_->exec_key,
                                      rf->second_pass, rf->sc_idx, rf->rank);
  VLOG(3) << "DispatchSend rank=" << col_params_->default_rank << " send key "
          << send_buf_key << " chunk " << ca_->TBounds(rf->chunk) << " sc_idx "
          << rf->sc_idx;
  int send_to_rank = (rf->rank + 1) % group_size_;
  int send_to_dev_idx = col_params_->instance.impl_details
                            .subdiv_permutations[rf->subdiv_idx][send_to_rank];
  col_ctx_->col_exec->remote_access()->PostToPeer(
      col_params_->group.device_names[send_to_dev_idx],
      col_params_->group.task_names[send_to_dev_idx], send_buf_key,
      col_ctx_->device, col_ctx_->op_ctx->op_device_context(),
      col_ctx_->op_ctx->output_alloc_attr(0), &rf->chunk,
      col_ctx_->device_locality, col_ctx_->op_ctx->cancellation_manager(),
      done);
}

void RingAlg::DispatchRecv(RingField* rf, const StatusCallback& done) {
  DCHECK(rf->do_recv);
  string recv_buf_key =
      RingAlgBufKey(name_, col_ctx_->exec_key, rf->second_pass, rf->sc_idx,
                    (rf->rank + (group_size_ - 1)) % group_size_);
  VLOG(3) << "DispatchRecv rank=" << col_params_->default_rank << " recv key "
          << recv_buf_key << " chunk " << ca_->TBounds(rf->chunk) << " into "
          << ((col_params_->merge_op != nullptr) ? "tmp_chunk" : "chunk");
  Tensor* dst_tensor = (!rf->second_pass && (col_params_->merge_op != nullptr))
                           ? &rf->tmp_chunk
                           : &rf->chunk;
  col_ctx_->col_exec->remote_access()->RecvFromPeer(
      col_params_->group.device_names[rf->recv_dev_idx],
      col_params_->group.task_names[rf->recv_dev_idx],
      col_params_->task.is_local[rf->recv_dev_idx], recv_buf_key,
      col_ctx_->device, col_ctx_->op_ctx->op_device_context(),
      col_ctx_->op_ctx->output_alloc_attr(0), dst_tensor,
      col_ctx_->device_locality, rf->subdiv_idx,
      col_ctx_->op_ctx->cancellation_manager(), done);
}

string RingAlg::FieldState() {
  string s = strings::StrCat(
      "Ring", name_, " ", strings::Hex(reinterpret_cast<uint64>(this)),
      " exec ", col_ctx_->exec_key, " step_id=", col_ctx_->step_id,
      " state of all ", rfv_.size(), " fields:");
  for (int i = 0; i < rfv_.size(); ++i) {
    s.append("\n");
    s.append(rfv_[i].DebugString());
  }
  return s;
}

}  // namespace tensorflow