1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h"
16
17 #include <functional>
18 #include <memory>
19 #include <string>
20 #include <utility>
21
22 #include "tensorflow/core/common_runtime/collective_rma_local.h"
23 #include "tensorflow/core/common_runtime/collective_util.h"
24 #include "tensorflow/core/common_runtime/device_mgr.h"
25 #include "tensorflow/core/common_runtime/dma_helper.h"
26 #include "tensorflow/core/framework/device_base.h"
27 #include "tensorflow/core/framework/op_kernel.h"
28 #include "tensorflow/core/framework/tensor.h"
29 #include "tensorflow/core/lib/core/notification.h"
30 #include "tensorflow/core/lib/core/status.h"
31 #include "tensorflow/core/lib/strings/str_util.h"
32 #include "tensorflow/core/lib/strings/strcat.h"
33 #include "tensorflow/core/platform/env.h"
34 #include "tensorflow/core/platform/types.h"
35
// Set to true to make the BufRendezvous keys built by BroadcastBufKey()
// human-readable in debug (VLOG) messages; false selects the more compact
// colon-separated key format.
#define READABLE_KEYS false
38
39 namespace tensorflow {
40
41 namespace {
42 // Key to be used for BufRendezvous by Broadcaster.
BroadcastBufKey(const string & exec_key,int subdiv,int src_rank,int dst_rank)43 string BroadcastBufKey(const string& exec_key, int subdiv, int src_rank,
44 int dst_rank) {
45 if (READABLE_KEYS) {
46 return strings::StrCat("broadcast(", exec_key, "):subdiv(", subdiv,
47 "):src(", src_rank, "):dst(", dst_rank, ")");
48 } else {
49 // TODO(b/78352018): Try a denser format, e.g. a 64 or 128 bit hash.
50 return strings::StrCat(exec_key, ":", subdiv, ":", src_rank, ":", dst_rank);
51 }
52 }
53 } // namespace
54
HierarchicalTreeBroadcaster()55 HierarchicalTreeBroadcaster::HierarchicalTreeBroadcaster()
56 : col_ctx_(nullptr),
57 col_params_(nullptr),
58 done_(nullptr),
59 is_source_(false) {}
60
GetDeviceTask(int device_rank,const std::vector<int> & dev_per_task)61 int HierarchicalTreeBroadcaster::GetDeviceTask(
62 int device_rank, const std::vector<int>& dev_per_task) {
63 int num_tasks = static_cast<int>(dev_per_task.size());
64 int task_lo = 0;
65 int task_hi = -1;
66 for (int ti = 0; ti < num_tasks; ti++) {
67 task_hi = task_lo + dev_per_task[ti];
68 if (task_lo <= device_rank && device_rank < task_hi) return ti;
69 task_lo = task_hi;
70 }
71 LOG(FATAL) << "Unexpected device rank " << device_rank << " for " << task_hi
72 << " devices";
73 return -1;
74 }
75
InitializeCollectiveParams(CollectiveParams * col_params)76 Status HierarchicalTreeBroadcaster::InitializeCollectiveParams(
77 CollectiveParams* col_params) {
78 CHECK_EQ(col_params->instance.type, BROADCAST_COLLECTIVE);
79 CHECK_EQ(col_params->instance.impl_details.collective_name,
80 "HierarchicalTreeBroadcast");
81 const string& device_name =
82 col_params->instance.device_names[col_params->default_rank];
83 // Start by counting the devices in each task.
84 // Precondition: device_names must be sorted so that all devices in
85 // the same task are adjacent.
86 VLOG(2) << "Sorted task names: "
87 << str_util::Join(col_params->instance.task_names, ", ");
88 std::vector<int> dev_per_task;
89 const string* prior_task_name = &col_params->instance.task_names[0];
90 int dev_count = 1;
91 for (int di = 1; di < col_params->group.group_size; ++di) {
92 if (col_params->instance.task_names[di] != *prior_task_name) {
93 dev_per_task.push_back(dev_count);
94 dev_count = 1;
95 prior_task_name = &col_params->instance.task_names[di];
96 } else {
97 ++dev_count;
98 }
99 }
100 dev_per_task.push_back(dev_count);
101 CHECK_EQ(col_params->group.num_tasks, dev_per_task.size());
102
103 if (VLOG_IS_ON(2)) {
104 string dpt_buf;
105 for (int dpt : dev_per_task) strings::StrAppend(&dpt_buf, dpt, ";");
106 VLOG(2) << "HierarchicalTreeBroadcaster::InitializeCollectiveParams device="
107 << device_name << " source_rank=" << col_params->source_rank
108 << " dev_per_task=" << dpt_buf;
109 }
110 int num_tasks = col_params->group.num_tasks;
111 // If there is just 1 task, then execute binary tree broadcast over all
112 // devices. Otherwise, the first subdiv is inter-task broadcast, and then
113 // there are N more subdivs, where N is #task.
114 int num_subdivs = num_tasks + (num_tasks > 1 ? 1 : 0);
115 int total_num_devices = 0;
116 for (int num_dev : dev_per_task) total_num_devices += num_dev;
117
118 col_params->instance.impl_details.subdiv_permutations.resize(num_subdivs);
119 col_params->subdiv_rank.reserve(num_subdivs);
120 col_params->instance.impl_details.subdiv_source_rank.reserve(num_subdivs);
121
122 // Inter-task subdiv. Pick one device from each task - this is the source
123 // device if it belongs to that task, or device 0 for that task. If a device
124 // does not participate in the subdiv, set subdiv_rank to -1.
125 if (num_tasks > 1) {
126 const int sdi = 0;
127 std::vector<int>& perm =
128 col_params->instance.impl_details.subdiv_permutations[sdi];
129 CHECK_EQ(perm.size(), 0);
130 int device_count = 0;
131 int source_task = GetDeviceTask(col_params->source_rank, dev_per_task);
132 for (int ti = 0; ti < col_params->group.num_tasks; ti++) {
133 bool participate = false;
134 if (source_task == ti) {
135 // Source device belongs to this task.
136 perm.push_back(col_params->source_rank);
137 participate =
138 col_params->instance.device_names[col_params->source_rank] ==
139 device_name;
140 } else {
141 // Source does not belong to this task, choose dev 0.
142 perm.push_back(device_count);
143 participate =
144 col_params->instance.device_names[device_count] == device_name;
145 }
146 if (participate) col_params->subdiv_rank.push_back(ti);
147 device_count += dev_per_task[ti];
148 }
149 if (col_params->subdiv_rank.empty()) col_params->subdiv_rank.push_back(-1);
150 col_params->instance.impl_details.subdiv_source_rank.push_back(source_task);
151 }
152
153 // Intra-task subdivs. Pick all devices in task ti for subdiv sdi. Set
154 // source to dev 0 for that task if it does not contain original source, else
155 // set to rank of original source. If a device does not participate in
156 // the subdiv, set subdiv_rank to -1;
157 int abs_di = 0;
158 for (int ti = 0; ti < col_params->group.num_tasks; ti++) {
159 const int sdi = ti + (num_tasks > 1 ? 1 : 0);
160 std::vector<int>& perm =
161 col_params->instance.impl_details.subdiv_permutations[sdi];
162 CHECK_EQ(perm.size(), 0);
163 bool participate = false;
164 int subdiv_source = 0;
165 for (int di = 0; di < dev_per_task[ti]; di++) {
166 perm.push_back(abs_di);
167 if (col_params->instance.device_names[abs_di] == device_name) {
168 participate = true;
169 col_params->subdiv_rank.push_back(di);
170 }
171 if (abs_di == col_params->source_rank) subdiv_source = di;
172 abs_di++;
173 }
174 if (!participate) col_params->subdiv_rank.push_back(-1);
175 col_params->instance.impl_details.subdiv_source_rank.push_back(
176 subdiv_source);
177 }
178
179 for (int sri = 0; sri < num_subdivs; sri++) {
180 CHECK_GE(col_params->instance.impl_details.subdiv_source_rank[sri], 0);
181 }
182
183 VLOG(2) << collective_util::SubdivPermDebugString(*col_params);
184 return Status::OK();
185 }
186
InitializeCollectiveContext(CollectiveContext * col_ctx)187 Status HierarchicalTreeBroadcaster::InitializeCollectiveContext(
188 CollectiveContext* col_ctx) {
189 CHECK(col_ctx->dev_mgr);
190 col_ctx_ = col_ctx;
191 col_params_ = &col_ctx->col_params;
192 return collective_util::InitializeDeviceAndLocality(
193 col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device,
194 &col_ctx->device_locality);
195 }
196
Run(StatusCallback done)197 void HierarchicalTreeBroadcaster::Run(StatusCallback done) {
198 CHECK(col_ctx_);
199 CHECK(col_params_);
200 done_ = std::move(done);
201 is_source_ = col_params_->is_source;
202 RunTree();
203 }
204
// Binary tree parent/child relations are trivial to calculate: the device at
// rank r is the parent of ranks 2r+1 and 2r+2.  The one exception is when the
// source is not rank 0.  We treat that case as though the source is prepended
// to the front of the rank ordering while also continuing to occupy its
// current position.  Hence we calculate as though each device's rank is
// actually r+1, then subtract 1 again to get the real ranks; e.g. the children
// of rank r become ranks 2r+2 and 2r+3.  If the source is not rank 0, its
// descendants therefore include both {0,1} and the descendants of its current
// position.  Where a non-0-rank source is itself a positional descendant of
// another device, no send to it is necessary.
215
216 /* static*/
TreeRecvFrom(const CollectiveParams & cp,int subdiv)217 int HierarchicalTreeBroadcaster::TreeRecvFrom(const CollectiveParams& cp,
218 int subdiv) {
219 DCHECK_LT(subdiv, static_cast<int>(cp.subdiv_rank.size()));
220 int my_rank = cp.subdiv_rank[subdiv];
221 if (-1 == my_rank) return -1;
222
223 const auto& impl = cp.instance.impl_details;
224 DCHECK_LT(subdiv, static_cast<int>(impl.subdiv_source_rank.size()));
225 int source_rank = impl.subdiv_source_rank[subdiv];
226 if (my_rank == source_rank) return -1;
227 if (source_rank == 0) {
228 return (my_rank - 1) / 2;
229 } else {
230 int predecessor_rank = (my_rank / 2) - 1;
231 return (predecessor_rank < 0) ? source_rank : predecessor_rank;
232 }
233 }
234
235 /* static */
TreeSendTo(const CollectiveParams & cp,int subdiv,std::vector<int> * targets)236 void HierarchicalTreeBroadcaster::TreeSendTo(const CollectiveParams& cp,
237 int subdiv,
238 std::vector<int>* targets) {
239 DCHECK_LT(subdiv, static_cast<int>(cp.subdiv_rank.size()));
240 int my_rank = cp.subdiv_rank[subdiv];
241 if (-1 == my_rank) return;
242
243 const auto& impl = cp.instance.impl_details;
244 DCHECK_LT(subdiv, static_cast<int>(impl.subdiv_source_rank.size()));
245 int source_rank = impl.subdiv_source_rank[subdiv];
246
247 int group_size = 0;
248 for (int i = 0; i < impl.subdiv_permutations[subdiv].size(); i++) {
249 if (impl.subdiv_permutations[subdiv][i] >= 0) {
250 group_size++;
251 }
252 }
253
254 targets->clear();
255 int successor_rank = 0;
256 if (source_rank == 0) {
257 successor_rank = (2 * my_rank) + 1;
258 } else {
259 successor_rank = (2 * (my_rank + 1));
260 }
261 DCHECK_NE(successor_rank, my_rank);
262 if (cp.is_source && source_rank != 0) {
263 // The source sends to rank 0,1 in addition to its positional
264 // descendants.
265 if (group_size > 1) {
266 targets->push_back(0);
267 }
268 if (group_size > 2 && source_rank != 1) {
269 targets->push_back(1);
270 }
271 }
272 for (int i = 0; i < 2; ++i) {
273 if (successor_rank < group_size && successor_rank != source_rank) {
274 targets->push_back(successor_rank);
275 }
276 ++successor_rank;
277 }
278 }
279
280 // Executes a hierarchical tree broadcast.
281 // Each subdiv is a broadcast between a subset of the devices.
282 // If there is only one task, there is one subdiv comprising a broadcast between
283 // all devices belonging to the task.
284 // If there are n tasks, n>1, then there are n+1 subdivs. In the first (global)
285 // subdiv, one device from each task participates in a binary tree broadcast.
286 // Each task receives a copy of the tensor on one device via this broadcast.
287 // Subsequent subdivs correspond to intra-task broadcasts. Subdiv i+1
288 // corresponds to broadcast between all devices on task i. Thus, each task
289 // participates in at most 2 subdivs.
RunTree()290 void HierarchicalTreeBroadcaster::RunTree() {
291 int num_subdivs = static_cast<int>(col_params_->subdiv_rank.size());
292 // TODO(b/78352018): this is easily improved when a node participates in both
293 // first and second subdivision. It would first send to its descendents in
294 // the first subdiv, then wait until all pending ops are finished before
295 // sending to descendents in second subdiv. A better implementation would
296 // collapse the two send blocks.
297 for (int si = 0; si < num_subdivs; si++) {
298 int my_rank = col_params_->subdiv_rank[si];
299 // If rank is -1, this device does not participate in this subdiv.
300 if (-1 == my_rank) continue;
301 int source_rank = col_params_->instance.impl_details.subdiv_source_rank[si];
302 if (VLOG_IS_ON(1)) {
303 string subdiv_buf;
304 for (int r : col_params_->instance.impl_details.subdiv_permutations[si]) {
305 strings::StrAppend(&subdiv_buf, r, ",");
306 }
307 VLOG(1) << "Running Broadcast tree device=" << col_ctx_->device_name
308 << " subdiv=" << si << " perm=" << subdiv_buf
309 << " my_rank=" << my_rank << " source_rank=" << source_rank;
310 }
311
312 mutex mu; // also guards status_ while callbacks are pending
313 int pending_count = 0; // GUARDED_BY(mu)
314 condition_variable all_done;
315
316 if (my_rank >= 0 && my_rank != source_rank) {
317 // Begin by receiving the value.
318 int recv_from_rank = TreeRecvFrom(*col_params_, si);
319 Notification note;
320 DispatchRecv(si, recv_from_rank, my_rank, col_ctx_->output,
321 [this, &mu, ¬e](const Status& s) {
322 mutex_lock l(mu);
323 status_.Update(s);
324 note.Notify();
325 });
326 note.WaitForNotification();
327 }
328
329 // Then forward value to all descendent devices.
330 if (my_rank >= 0 && status_.ok()) {
331 std::vector<int> send_to_ranks;
332 TreeSendTo(*col_params_, si, &send_to_ranks);
333 for (int i = 0; i < send_to_ranks.size(); ++i) {
334 int target_rank = send_to_ranks[i];
335 {
336 mutex_lock l(mu);
337 ++pending_count;
338 }
339 DispatchSend(si, target_rank, my_rank,
340 (is_source_ ? col_ctx_->input : col_ctx_->output),
341 [this, &mu, &pending_count, &all_done](const Status& s) {
342 mutex_lock l(mu);
343 status_.Update(s);
344 --pending_count;
345 if (pending_count == 0) {
346 all_done.notify_all();
347 }
348 });
349 }
350 }
351
352 // For the original source device, we copy input to output if they are
353 // different.
354 // If there is only 1 subdiv, we do this in that subdiv. If there is more
355 // than 1 subdiv, then the original source device will participate in 2
356 // subdivs - the global inter-task broadcast and one local intra-task
357 // broadcast. In this case, we perform the copy in the second subdiv for
358 // this device.
359 if (status_.ok() && is_source_ && (1 == num_subdivs || 0 != si)) {
360 VLOG(2) << "copying input to output for device=" << col_ctx_->device_name
361 << " subdiv=" << si;
362 if (col_ctx_->input != col_ctx_->output &&
363 (DMAHelper::base(col_ctx_->input) !=
364 DMAHelper::base(col_ctx_->output))) {
365 {
366 mutex_lock l(mu);
367 ++pending_count;
368 }
369 DeviceContext* op_dev_ctx = col_ctx_->op_ctx->op_device_context();
370 CollectiveRemoteAccessLocal::MemCpyAsync(
371 op_dev_ctx, op_dev_ctx, col_ctx_->device, col_ctx_->device,
372 col_ctx_->op_ctx->input_alloc_attr(0),
373 col_ctx_->op_ctx->output_alloc_attr(0), col_ctx_->input,
374 col_ctx_->output, 0, /*stream_index*/
375 [this, &mu, &pending_count, &all_done](const Status& s) {
376 mutex_lock l(mu);
377 status_.Update(s);
378 --pending_count;
379 if (0 == pending_count) {
380 all_done.notify_all();
381 }
382 });
383 }
384 }
385
386 // Then wait for all pending actions to complete.
387 {
388 mutex_lock l(mu);
389 if (pending_count > 0) {
390 all_done.wait(l);
391 }
392 }
393 }
394 VLOG(2) << "device=" << col_ctx_->device_name << " return status " << status_;
395 done_(status_);
396 }
397
DispatchSend(int subdiv,int dst_rank,int src_rank,const Tensor * src_tensor,const StatusCallback & done)398 void HierarchicalTreeBroadcaster::DispatchSend(int subdiv, int dst_rank,
399 int src_rank,
400 const Tensor* src_tensor,
401 const StatusCallback& done) {
402 string send_buf_key =
403 BroadcastBufKey(col_ctx_->exec_key, subdiv, src_rank, dst_rank);
404 int dst_idx =
405 col_params_->instance.impl_details.subdiv_permutations[subdiv][dst_rank];
406 VLOG(3) << "DispatchSend " << send_buf_key << " from_device "
407 << col_ctx_->device_name << " to_device "
408 << col_params_->instance.device_names[dst_idx] << " subdiv=" << subdiv
409 << " dst_rank=" << dst_rank << " dst_idx=" << dst_idx;
410 col_ctx_->col_exec->PostToPeer(col_params_->instance.device_names[dst_idx],
411 col_params_->instance.task_names[dst_idx],
412 send_buf_key, col_ctx_->device,
413 col_ctx_->op_ctx->op_device_context(),
414 col_ctx_->op_ctx->output_alloc_attr(0),
415 src_tensor, col_ctx_->device_locality, done);
416 }
417
DispatchRecv(int subdiv,int src_rank,int dst_rank,Tensor * dst_tensor,const StatusCallback & done)418 void HierarchicalTreeBroadcaster::DispatchRecv(int subdiv, int src_rank,
419 int dst_rank, Tensor* dst_tensor,
420 const StatusCallback& done) {
421 string recv_buf_key =
422 BroadcastBufKey(col_ctx_->exec_key, subdiv, src_rank, dst_rank);
423 int src_idx =
424 col_params_->instance.impl_details.subdiv_permutations[subdiv][src_rank];
425 VLOG(3) << "DispatchRecv " << recv_buf_key << " from_device "
426 << col_params_->instance.device_names[src_idx] << " to_device "
427 << col_ctx_->device_name << " subdiv=" << subdiv
428 << " src_rank=" << src_rank << " src_idx=" << src_idx;
429 col_ctx_->col_exec->RecvFromPeer(
430 col_params_->instance.device_names[src_idx],
431 col_params_->instance.task_names[src_idx],
432 col_params_->task.is_local[src_idx], recv_buf_key, col_ctx_->device,
433 col_ctx_->op_ctx->op_device_context(),
434 col_ctx_->op_ctx->output_alloc_attr(0), dst_tensor,
435 col_ctx_->device_locality, 0 /*stream_index*/, done);
436 }
437
// Makes this class available as the "HierarchicalTreeBroadcast" collective
// implementation (presumably via the collective registry; confirm against the
// REGISTER_COLLECTIVE definition).  This is the collective_name that
// InitializeCollectiveParams() CHECKs for.
REGISTER_COLLECTIVE(HierarchicalTreeBroadcast, HierarchicalTreeBroadcaster);
439
440 } // namespace tensorflow
441