/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/ring_alg.h"

#include <stdlib.h>

#include <atomic>
#include <functional>
#include <utility>

#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/collective_rma_local.h"
#include "tensorflow/core/common_runtime/collective_util.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/types.h"

// Set to true for more intelligible debug-mode log messages.
#define READABLE_KEYS false
// A ring algorithm exchanges chunks of a tensor between devices. The chunk
// size depends on the number of subdivisions specified in the algorithm. If
// the user does not specify the number of subdivisions, we may infer the
// number dynamically so that the resulting chunk size does not exceed
// kMaxChunkSizeBytes, empirically set at 4 MiB.
constexpr size_t kMaxChunkSizeBytes = (4 * 1024 * 1024);
// kMaxSubdivsPerDevice gives an upper bound on the number of subdivisions
// generated dynamically. A reasonable value would be a small multiple of the
// number of NICs adjacent to each device.
constexpr int kMaxSubdivsPerDevice = 2;
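// For instance (hypothetical layout): with kMaxSubdivsPerDevice = 2 and
// 8 devices spread over 2 tasks (4 devices per task on average), dynamic
// generation produces at most 2 * 4 = 8 subdivisions; see
// GenerateSubdivsInCollectiveParams() below.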

namespace tensorflow {
namespace {
// Each CollectiveOp implementation is free to define its own
// BufRendezvous key format. This function produces the key used by
// RingAlg instances. Note that the exec_key already differentiates between
// different instances, so we don't need to further differentiate between
// subclasses of RingAlg.
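// For example (illustrative values only), with READABLE_KEYS a key might look
// like "RingReduce(<exec_key>):pass(0):section(3):srcrank(2)"; the default
// compact form of the same key is "<exec_key>:0:3:2".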
string RingAlgBufKey(const string& name, const string& exec_key, int pass,
                     int section, int source_rank) {
  if (READABLE_KEYS) {
    return strings::StrCat(name, "(", exec_key, "):pass(", pass, "):section(",
                           section, "):srcrank(", source_rank, ")");
  } else {
    // TODO(b/78352018): Try out some kind of denser encoding, e.g. 128 bit
    // hash.
    return strings::StrCat(exec_key, ":", pass, ":", section, ":",
                           source_rank);
  }
}

}  // namespace

void RingAlg::PCQueue::Enqueue(RingField* rf) {
  mutex_lock l(pcq_mu_);
  deque_.push_back(rf);
  if (waiter_count_ > 0) {
    cv_.notify_one();
  }
}

RingAlg::RingField* RingAlg::PCQueue::Dequeue() {
  mutex_lock l(pcq_mu_);
  if (deque_.empty()) {
    ++waiter_count_;
    while (deque_.empty()) {
      cv_.wait(l);
    }
    --waiter_count_;
  }
  RingField* rf = deque_.front();
  deque_.pop_front();
  return rf;
}

RingAlg::RingAlg(CollectiveType type, const string& name)
    : type_(type),
      name_(name),
      col_ctx_(nullptr),
      col_params_(nullptr),
      done_(nullptr),
      group_size_(-1),
      num_subdivs_(-1) {}

namespace {
Status GenerateSubdivsInCollectiveParams(CollectiveParams* col_params) {
  if (col_params->instance.shape.num_elements() == 0) {
    return errors::Internal("shape in CollectiveParams should be non-empty");
  }
  const int kAvgDevPerTask =
      col_params->group.group_size / col_params->group.num_tasks;
  const int kMaxNumSubdivs = kMaxSubdivsPerDevice * kAvgDevPerTask;
  if (kMaxNumSubdivs <= 0) {
    return errors::Internal("Unexpected kMaxNumSubdivs ", kMaxNumSubdivs,
                            " in ",
                            col_params->instance.impl_details.collective_name);
  }
  // NOTE(ayushd): If no subdiv_offsets have been specified, dynamically add
  // as many offsets as needed so that the size of tensor chunks <=
  // kMaxChunkSizeBytes. Empirically, chunks that are too small or too large
  // lead to worse performance.
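  //
  // Worked example (hypothetical numbers): a 24 MiB tensor with
  // group_size = 4 over 2 tasks gives 4 chunks of 6 MiB at num_subdivs = 1,
  // which exceeds kMaxChunkSizeBytes, so the loop below tries
  // num_subdivs = 2, giving 8 chunks of 3 MiB; that is within the 4 MiB
  // limit and within kMaxNumSubdivs = 2 * 2 = 4, so the loop stops there.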
  int num_subdivs = 0;
  const size_t tensor_size = col_params->instance.shape.num_elements() *
                             DataTypeSize(col_params->instance.data_type);
  size_t chunk_size;
  do {
    ++num_subdivs;
    int num_chunks = col_params->group.group_size * num_subdivs;
    chunk_size = tensor_size / num_chunks;
    VLOG(2) << "num_subdivs " << num_subdivs << " num_chunks " << num_chunks
            << " chunk_size " << chunk_size;
  } while (chunk_size > kMaxChunkSizeBytes && num_subdivs < kMaxNumSubdivs);
  if (num_subdivs <= 0) {
    return errors::Internal("Unexpected num_subdivs ", num_subdivs, " in ",
                            col_params->instance.impl_details.collective_name);
  }

  int subdiv_stride = kAvgDevPerTask / num_subdivs;
  if (subdiv_stride == 0) subdiv_stride = 1;
  col_params->instance.impl_details.subdiv_offsets.reserve(num_subdivs);
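  // Generate offsets that alternate in sign: 0, -stride, 2 * stride,
  // -3 * stride, ...  A negative offset is interpreted by
  // InitializeCollectiveParams() as "reverse the per-task device order".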
  for (int sdi = 0; sdi < num_subdivs; ++sdi) {
    int subdiv_offset = subdiv_stride * sdi;
    if (sdi % 2 == 1) subdiv_offset *= -1;
    col_params->instance.impl_details.subdiv_offsets.push_back(subdiv_offset);
  }

  if (VLOG_IS_ON(2)) {
    string subdiv_buf;
    for (const int subdiv_offset :
         col_params->instance.impl_details.subdiv_offsets) {
      strings::StrAppend(&subdiv_buf, " ", subdiv_offset);
    }
    VLOG(2) << "Dynamically generated " << num_subdivs
            << " subdiv_offsets:" << subdiv_buf << " tensor_size "
            << tensor_size << " chunk_size " << chunk_size;
  }

  return Status::OK();
}
}  // namespace

Status RingAlg::InitializeCollectiveParams(CollectiveParams* col_params) {
  const string& device_name =
      col_params->group.device_names[col_params->default_rank];
  // Each subdiv permutation is a ring formed by rotating each
  // single-task subsequence of devices by an offset. This makes the most
  // sense when each task has the same number of devices, but we can't
  // depend on that being the case, so we compute something that works
  // in any case.

  // Start by counting the devices in each task.
  // Precondition: device_names must be sorted so that all devices in
  // the same task are adjacent.
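  //
  // For example (hypothetical layout): with task_names sorted as
  // {A, A, A, B, B, B}, the loop below yields dev_per_task = {3, 3}.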
  VLOG(2) << "Sorted task names: "
          << absl::StrJoin(col_params->group.task_names, ", ");
  std::vector<int> dev_per_task;
  const string* prior_task_name = &col_params->group.task_names[0];
  int dev_count = 1;
  for (int di = 1; di < col_params->group.group_size; ++di) {
    if (col_params->group.task_names[di] != *prior_task_name) {
      dev_per_task.push_back(dev_count);
      dev_count = 1;
      prior_task_name = &col_params->group.task_names[di];
    } else {
      ++dev_count;
    }
  }
  dev_per_task.push_back(dev_count);
  DCHECK_EQ(col_params->group.num_tasks, dev_per_task.size());

  if (col_params->instance.impl_details.subdiv_offsets.empty()) {
    TF_RETURN_IF_ERROR(GenerateSubdivsInCollectiveParams(col_params));
  }

  // Generate a ring permutation for each requested offset.
  VLOG(2) << "Setting up perms for col_params " << col_params
          << " subdiv_permutations "
          << &col_params->instance.impl_details.subdiv_permutations;
  col_params->instance.impl_details.subdiv_permutations.resize(
      col_params->instance.impl_details.subdiv_offsets.size());
  col_params->subdiv_rank.resize(
      col_params->instance.impl_details.subdiv_offsets.size(), -1);
  for (int sdi = 0;
       sdi < col_params->instance.impl_details.subdiv_offsets.size(); ++sdi) {
    std::vector<int>& perm =
        col_params->instance.impl_details.subdiv_permutations[sdi];
    DCHECK_EQ(perm.size(), 0);
    int offset = col_params->instance.impl_details.subdiv_offsets[sdi];
    // A negative subdivision offset is interpreted as follows:
    //  1. Reverse the local device ordering.
    //  2. Begin the subdivision at abs(offset) in the reversed ordering.
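    //
    // Worked example (hypothetical values): with dev_per_task = {4, 4} and
    // subdiv_offset = -1, each task's devices are taken in reverse order
    // starting at abs(-1) = 1, giving the per-task order {2, 1, 0, 3} and
    // the full permutation {2, 1, 0, 3, 6, 5, 4, 7}.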
    bool reverse = false;
    if (offset < 0) {
      offset = abs(offset);
      reverse = true;
    }
    int prior_dev_count = 0;  // sum over prior worker device counts
    for (int ti = 0; ti < col_params->group.num_tasks; ++ti) {
      for (int di = 0; di < dev_per_task[ti]; ++di) {
        int di_offset = (di + offset) % dev_per_task[ti];
        int offset_di =
            reverse ? (dev_per_task[ti] - (di_offset + 1)) : di_offset;
        // Device index in global subdivision permutation.
        int permuted_di = prior_dev_count + offset_di;
        int rank = static_cast<int>(perm.size());
        perm.push_back(permuted_di);
        if (col_params->group.device_names[permuted_di] == device_name) {
          DCHECK_EQ(permuted_di, col_params->default_rank);
          col_params->subdiv_rank[sdi] = rank;
        }
      }
      prior_dev_count += dev_per_task[ti];
    }
    DCHECK_EQ(col_params->group.group_size, perm.size());
  }

  VLOG(2) << collective_util::SubdivPermDebugString(*col_params);
  return Status::OK();
}

Status RingAlg::InitializeCollectiveContext(
    std::shared_ptr<CollectiveContext> col_ctx) {
  DCHECK(col_ctx->dev_mgr);
  col_ctx_ = col_ctx;
  col_params_ = col_ctx->col_params;
  return collective_util::InitializeDeviceAndLocality(
      col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device,
      &col_ctx->device_locality);
}

string RingAlg::TensorDebugString(const Tensor& tensor) {
  const DeviceBase::GpuDeviceInfo* gpu_device_info =
      col_ctx_->op_ctx->device()->tensorflow_gpu_device_info();
  if (gpu_device_info) {
    Tensor cpu_tensor(tensor.dtype(), tensor.shape());
    Status st = gpu_device_info->default_context->CopyDeviceTensorToCPUSync(
        &tensor, "" /*tensor_name*/, col_ctx_->device, &cpu_tensor);
    DCHECK(st.ok());
    return cpu_tensor.SummarizeValue(64);
  } else {
    return tensor.SummarizeValue(64);
  }
}

void RingAlg::StartAbort(const Status& s) {
  // In abort mode we stop issuing additional ProvideBuf
  // and ConsumeBuf calls, but we need to wait for all of the
  // outstanding callbacks to be invoked before quitting.
  bool abort_started = false;
  {
    mutex_lock l(status_mu_);
    if (status_.ok()) {
      LOG(ERROR) << "Aborting Ring" << name_ << " with " << s;
      abort_started = true;
      status_.Update(s);
    }
  }
  // If this is the initial entry to abort mode and it's not a cancellation,
  // then invoke StartAbort on the CollectiveExecutor that invoked us. That
  // should start cancellation on all of the outstanding CollectiveRemoteAccess
  // actions. If it's a cancellation, all pending sends/recvs should be
  // cancelled as well, so there's no need to abort.
  if (abort_started) {
    if (col_ctx_->op_ctx->cancellation_manager() == nullptr ||
        (!col_ctx_->op_ctx->cancellation_manager()->IsCancelled() &&
         !col_ctx_->op_ctx->cancellation_manager()->IsCancelling())) {
      col_ctx_->col_exec->StartAbort(s);
    }
  }
}

void RingAlg::Finish(bool ok) {
  if (ok) {
    // Recover the output from the adaptor.
    ca_->ConsumeFinalValue(col_ctx_->output);
  }
  Status s;
  {
    mutex_lock l(status_mu_);
    s = status_;
  }
  rfv_.clear();  // Give up Refs on output tensor.
  done_(s);
}

// At the beginning of the algorithm initialize a RingField struct for
// every independent field of the tensor.
void RingAlg::InitRingField(RingField* rf, int chunk_idx, int subdiv_idx,
                            int field_idx) {
  // Note on field indexing: There are group_size_ devices in the
  // instance, implying the same number of chunks per tensor, where a
  // chunk is the unit of data transferred in a time step. However, if
  // a device can simultaneously send data over 2 or more independent
  // channels we can speed up the transfer by subdividing chunks and
  // processing multiple subdivisions at once. So the actual number
  // of RingFields is group_size_ * num_subdivs_.
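  //
  // For example (hypothetical sizes): group_size_ = 4 and num_subdivs_ = 2
  // yield 8 RingFields, with field_idx = chunk_idx * 2 + subdiv_idx, which is
  // exactly the relationship the DCHECK below verifies.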
  DCHECK_EQ(field_idx, (chunk_idx * num_subdivs_) + subdiv_idx);
  rf->chunk_idx = chunk_idx;
  rf->subdiv_idx = subdiv_idx;
  rf->sc_idx = field_idx;
  rf->rank = col_params_->subdiv_rank[subdiv_idx];
  rf->second_pass = false;
  rf->action = RF_INIT;
  // Recv from the device with preceding rank within the subdivision.
  int recv_from_rank = (rf->rank + (group_size_ - 1)) % group_size_;
  int send_to_rank = (rf->rank + 1) % group_size_;
  rf->recv_dev_idx = col_params_->instance.impl_details
                         .subdiv_permutations[subdiv_idx][recv_from_rank];
  int send_dev_idx = col_params_->instance.impl_details
                         .subdiv_permutations[subdiv_idx][send_to_rank];
  rf->recv_is_remote = !col_params_->task.is_local[rf->recv_dev_idx];
  rf->send_is_remote = !col_params_->task.is_local[send_dev_idx];
  if (ca_->ChunkBytes(rf->sc_idx) > 0) {
    // In pass 0 we skip Recv when rank = chunk_idx
    rf->do_recv = (rf->chunk_idx != rf->rank);
    // In pass 0 we skip Send when rank = chunk_idx-1
    rf->do_send =
        (rf->rank != ((rf->chunk_idx + (group_size_ - 1)) % group_size_));
  }
  rf->is_final =
      (rf->rank == ((rf->chunk_idx + (group_size_ - 1)) % group_size_));
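  // Worked example (hypothetical values): with group_size_ = 4 and
  // chunk_idx = 2, the rank-2 field gets do_recv = false in pass 0, the
  // rank-1 field ((chunk_idx - 1) mod group_size_) gets do_send = false,
  // and that same rank-1 field is the one marked is_final.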
  if (rf->do_send || rf->do_recv) {
    rf->chunk = ca_->ChunkAlias(rf->sc_idx);
  }
  VLOG(2) << this << " InitRingField " << rf->DebugString() << " chunk "
          << ca_->TBounds(rf->chunk);
}

// When a RingField transitions from the first pass to the second,
// recompute the do_send and do_recv values.
void RingAlg::AdvanceToSecondPass(RingField* rf) {
  VLOG(3) << "IncrRingField old value " << rf->DebugString();
  DCHECK(!rf->second_pass);
  rf->second_pass = true;
  rf->action = RF_INIT;
  if (ca_->ChunkBytes(rf->sc_idx) > 0) {
    // In pass 1 the send/no-send boundary moves down 1 place.
    rf->do_recv =
        (rf->rank != ((rf->chunk_idx + (group_size_ - 1)) % group_size_));
    rf->do_send =
        (rf->rank != ((rf->chunk_idx + (group_size_ - 2)) % group_size_));
  }
  rf->is_final =
      (rf->rank == ((rf->chunk_idx + (group_size_ - 2)) % group_size_));
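  // Continuing the hypothetical example (group_size_ = 4, chunk_idx = 2):
  // in pass 1, rank 1 now skips the Recv, rank 0
  // ((chunk_idx - 2) mod group_size_) skips the Send, and rank 0 becomes
  // is_final.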
  VLOG(3) << "IncrRingField new value " << rf->DebugString();
}

string RingAlg::RingField::DebugString() const {
  string rv = strings::StrCat("RingField rank=", rank, " chunk_idx=", chunk_idx,
                              " subdiv=", subdiv_idx, " sc_idx=", sc_idx,
                              " action=", action);
  strings::StrAppend(&rv, " pass=", second_pass);
  strings::StrAppend(&rv, " do_send=", do_send, " do_recv=", do_recv,
                     " is_final=", is_final, " recv_is_remote=", recv_is_remote,
                     " recv_dev_idx=", recv_dev_idx, " sc_idx=", sc_idx);
  return rv;
}

void RingAlg::DispatchSend(RingField* rf, const StatusCallback& done) {
  DCHECK(rf->do_send);
  string send_buf_key = RingAlgBufKey(name_, col_ctx_->exec_key,
                                      rf->second_pass, rf->sc_idx, rf->rank);
  VLOG(3) << "DispatchSend rank=" << col_params_->default_rank << " send key "
          << send_buf_key << " chunk " << ca_->TBounds(rf->chunk) << " sc_idx "
          << rf->sc_idx;
  int send_to_rank = (rf->rank + 1) % group_size_;
  int send_to_dev_idx = col_params_->instance.impl_details
                            .subdiv_permutations[rf->subdiv_idx][send_to_rank];
  col_ctx_->col_exec->remote_access()->PostToPeer(
      col_params_->group.device_names[send_to_dev_idx],
      col_params_->group.task_names[send_to_dev_idx], send_buf_key,
      col_ctx_->device, col_ctx_->op_ctx->op_device_context(),
      col_ctx_->op_ctx->output_alloc_attr(0), &rf->chunk,
      col_ctx_->device_locality, col_ctx_->op_ctx->cancellation_manager(),
      done);
}

void RingAlg::DispatchRecv(RingField* rf, const StatusCallback& done) {
  DCHECK(rf->do_recv);
  string recv_buf_key =
      RingAlgBufKey(name_, col_ctx_->exec_key, rf->second_pass, rf->sc_idx,
                    (rf->rank + (group_size_ - 1)) % group_size_);
  VLOG(3) << "DispatchRecv rank=" << col_params_->default_rank << " recv key "
          << recv_buf_key << " chunk " << ca_->TBounds(rf->chunk) << " into "
          << ((col_params_->merge_op != nullptr) ? "tmp_chunk" : "chunk");
  Tensor* dst_tensor = (!rf->second_pass && (col_params_->merge_op != nullptr))
                           ? &rf->tmp_chunk
                           : &rf->chunk;
  col_ctx_->col_exec->remote_access()->RecvFromPeer(
      col_params_->group.device_names[rf->recv_dev_idx],
      col_params_->group.task_names[rf->recv_dev_idx],
      col_params_->task.is_local[rf->recv_dev_idx], recv_buf_key,
      col_ctx_->device, col_ctx_->op_ctx->op_device_context(),
      col_ctx_->op_ctx->output_alloc_attr(0), dst_tensor,
      col_ctx_->device_locality, rf->subdiv_idx,
      col_ctx_->op_ctx->cancellation_manager(), done);
}

string RingAlg::FieldState() {
  string s = strings::StrCat(
      "Ring", name_, " ", strings::Hex(reinterpret_cast<uint64>(this)),
      " exec ", col_ctx_->exec_key, " step_id=", col_ctx_->step_id,
      " state of all ", rfv_.size(), " fields:");
  for (int i = 0; i < rfv_.size(); ++i) {
    s.append("\n");
    s.append(rfv_[i].DebugString());
  }
  return s;
}

}  // namespace tensorflow