• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h"
16 
17 #include <functional>
18 #include <memory>
19 #include <string>
20 #include <utility>
21 
22 #include "tensorflow/core/common_runtime/collective_rma_local.h"
23 #include "tensorflow/core/common_runtime/collective_util.h"
24 #include "tensorflow/core/common_runtime/device_mgr.h"
25 #include "tensorflow/core/common_runtime/dma_helper.h"
26 #include "tensorflow/core/framework/device_base.h"
27 #include "tensorflow/core/framework/op_kernel.h"
28 #include "tensorflow/core/framework/tensor.h"
29 #include "tensorflow/core/lib/core/notification.h"
30 #include "tensorflow/core/lib/core/status.h"
31 #include "tensorflow/core/lib/strings/str_util.h"
32 #include "tensorflow/core/lib/strings/strcat.h"
33 #include "tensorflow/core/platform/env.h"
34 #include "tensorflow/core/platform/types.h"
35 #include "tensorflow/core/profiler/lib/scoped_memory_debug_annotation.h"
36 #include "tensorflow/core/profiler/lib/traceme.h"
37 
38 // Set true for greater intelligibility of debug mode log messages.
39 #define READABLE_KEYS false
40 
41 namespace tensorflow {
42 
43 namespace {
44 // Key to be used for BufRendezvous by Broadcaster.
BroadcastBufKey(const string & exec_key,int subdiv,int src_rank,int dst_rank)45 string BroadcastBufKey(const string& exec_key, int subdiv, int src_rank,
46                        int dst_rank) {
47   if (READABLE_KEYS) {
48     return strings::StrCat("broadcast(", exec_key, "):subdiv(", subdiv,
49                            "):src(", src_rank, "):dst(", dst_rank, ")");
50   } else {
51     // TODO(b/78352018): Try a denser format, e.g. a 64 or 128 bit hash.
52     return strings::StrCat(exec_key, ":", subdiv, ":", src_rank, ":", dst_rank);
53   }
54 }
55 }  // namespace
56 
HierarchicalTreeBroadcaster()57 HierarchicalTreeBroadcaster::HierarchicalTreeBroadcaster()
58     : col_ctx_(nullptr),
59       col_params_(nullptr),
60       done_(nullptr),
61       is_source_(false) {}
62 
GetDeviceTask(int device_rank,const std::vector<int> & dev_per_task)63 int HierarchicalTreeBroadcaster::GetDeviceTask(
64     int device_rank, const std::vector<int>& dev_per_task) {
65   int num_tasks = static_cast<int>(dev_per_task.size());
66   int task_lo = 0;
67   int task_hi = -1;
68   for (int ti = 0; ti < num_tasks; ti++) {
69     task_hi = task_lo + dev_per_task[ti];
70     if (task_lo <= device_rank && device_rank < task_hi) return ti;
71     task_lo = task_hi;
72   }
73   LOG(FATAL) << "Unexpected device rank " << device_rank << " for " << task_hi
74              << " devices";
75   return -1;
76 }
77 
InitializeCollectiveParams(CollectiveParams * col_params)78 Status HierarchicalTreeBroadcaster::InitializeCollectiveParams(
79     CollectiveParams* col_params) {
80   CHECK_EQ(col_params->instance.type, BROADCAST_COLLECTIVE);
81   CHECK_EQ(col_params->instance.impl_details.collective_name,
82            "HierarchicalTreeBroadcast");
83   const string& device_name =
84       col_params->group.members[col_params->default_rank].device.name();
85   // Start by counting the devices in each task.
86   // Precondition: device_names must be sorted so that all devices in
87   // the same task are adjacent.
88   std::vector<int> dev_per_task;
89   const string* prior_task_name = &col_params->group.members[0].task;
90   int dev_count = 1;
91   for (int di = 1; di < col_params->group.group_size; ++di) {
92     if (col_params->group.members[di].task != *prior_task_name) {
93       dev_per_task.push_back(dev_count);
94       dev_count = 1;
95       prior_task_name = &col_params->group.members[di].task;
96     } else {
97       ++dev_count;
98     }
99   }
100   dev_per_task.push_back(dev_count);
101   CHECK_EQ(col_params->group.num_tasks, dev_per_task.size());
102 
103   if (VLOG_IS_ON(2)) {
104     string dpt_buf;
105     for (int dpt : dev_per_task) strings::StrAppend(&dpt_buf, dpt, ";");
106     VLOG(2) << "HierarchicalTreeBroadcaster::InitializeCollectiveParams device="
107             << device_name << " source_rank=" << col_params->source_rank
108             << " dev_per_task=" << dpt_buf;
109   }
110   int num_tasks = col_params->group.num_tasks;
111   // If there is just 1 task, then execute binary tree broadcast over all
112   // devices.  Otherwise, the first subdiv is inter-task broadcast, and then
113   // there are N more subdivs, where N is #task.
114   int num_subdivs = num_tasks + (num_tasks > 1 ? 1 : 0);
115   int total_num_devices = 0;
116   for (int num_dev : dev_per_task) total_num_devices += num_dev;
117 
118   col_params->instance.impl_details.subdiv_permutations.resize(num_subdivs);
119   col_params->subdiv_rank.reserve(num_subdivs);
120   col_params->instance.impl_details.subdiv_source_rank.reserve(num_subdivs);
121 
122   // Inter-task subdiv.  Pick one device from each task - this is the source
123   // device if it belongs to that task, or device 0 for that task.  If a device
124   // does not participate in the subdiv, set subdiv_rank to -1.
125   if (num_tasks > 1) {
126     const int sdi = 0;
127     std::vector<int>& perm =
128         col_params->instance.impl_details.subdiv_permutations[sdi];
129     CHECK_EQ(perm.size(), 0);
130     int device_count = 0;
131     int source_task = GetDeviceTask(col_params->source_rank, dev_per_task);
132     for (int ti = 0; ti < col_params->group.num_tasks; ti++) {
133       bool participate = false;
134       if (source_task == ti) {
135         // Source device belongs to this task.
136         perm.push_back(col_params->source_rank);
137         participate =
138             col_params->group.members[col_params->source_rank].device.name() ==
139             device_name;
140       } else {
141         // Source does not belong to this task, choose dev 0.
142         perm.push_back(device_count);
143         participate = col_params->group.members[device_count].device.name() ==
144                       device_name;
145       }
146       if (participate) col_params->subdiv_rank.push_back(ti);
147       device_count += dev_per_task[ti];
148     }
149     if (col_params->subdiv_rank.empty()) col_params->subdiv_rank.push_back(-1);
150     col_params->instance.impl_details.subdiv_source_rank.push_back(source_task);
151   }
152   VLOG(2) << collective_util::SubdivPermDebugString(*col_params);
153 
154   // Intra-task subdivs.  Pick all devices in task ti for subdiv sdi.  Set
155   // source to dev 0 for that task if it does not contain original source, else
156   // set to rank of original source.  If a device does not participate in
157   // the subdiv, set subdiv_rank to -1;
158   int abs_di = 0;
159   for (int ti = 0; ti < col_params->group.num_tasks; ti++) {
160     const int sdi = ti + (num_tasks > 1 ? 1 : 0);
161     std::vector<int>& perm =
162         col_params->instance.impl_details.subdiv_permutations[sdi];
163     CHECK_EQ(perm.size(), 0);
164     bool participate = false;
165     int subdiv_source = 0;
166     for (int di = 0; di < dev_per_task[ti]; di++) {
167       perm.push_back(abs_di);
168       if (col_params->group.members[abs_di].device.name() == device_name) {
169         participate = true;
170         col_params->subdiv_rank.push_back(di);
171       }
172       if (abs_di == col_params->source_rank) subdiv_source = di;
173       abs_di++;
174     }
175     if (!participate) col_params->subdiv_rank.push_back(-1);
176     col_params->instance.impl_details.subdiv_source_rank.push_back(
177         subdiv_source);
178   }
179 
180   for (int sri = 0; sri < num_subdivs; sri++) {
181     CHECK_GE(col_params->instance.impl_details.subdiv_source_rank[sri], 0);
182   }
183 
184   VLOG(2) << collective_util::SubdivPermDebugString(*col_params);
185   return OkStatus();
186 }
187 
InitializeCollectiveContext(std::shared_ptr<CollectiveContext> col_ctx)188 Status HierarchicalTreeBroadcaster::InitializeCollectiveContext(
189     std::shared_ptr<CollectiveContext> col_ctx) {
190   CHECK(col_ctx->dev_mgr);
191   col_ctx_ = col_ctx;
192   col_params_ = col_ctx->col_params.get();
193   return collective_util::InitializeDeviceAndLocality(
194       col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device,
195       &col_ctx->device_locality);
196 }
197 
Run(StatusCallback done)198 void HierarchicalTreeBroadcaster::Run(StatusCallback done) {
199   CHECK(col_ctx_);
200   CHECK(col_params_);
201   done_ = std::move(done);
202   is_source_ = col_params_->is_source;
203   RunTree();
204 }
205 
// Binary tree parent/child relations are trivial to calculate, i.e.
// device at rank r is the parent of 2r+1 and 2r+2.  The one exception
// is if the source is not rank 0.  We treat that case as though the
// source is prepended to the front of the rank ordering as well as
// continuing to occupy its current position.  Hence we calculate as
// though each device's rank is actually r+1, then subtract 1 again to
// get the descendant ranks.  If the source is not rank 0 then its
// descendants include both {0,1} and the descendants of its current
// position.  Where a non-0-rank source is a descendant of another
// device, no send to it is necessary.
216 
217 /* static*/
TreeRecvFrom(const CollectiveParams & cp,int subdiv)218 int HierarchicalTreeBroadcaster::TreeRecvFrom(const CollectiveParams& cp,
219                                               int subdiv) {
220   DCHECK_LT(subdiv, static_cast<int>(cp.subdiv_rank.size()));
221   int my_rank = cp.subdiv_rank[subdiv];
222   if (-1 == my_rank) return -1;
223 
224   const auto& impl = cp.instance.impl_details;
225   DCHECK_LT(subdiv, static_cast<int>(impl.subdiv_source_rank.size()));
226   int source_rank = impl.subdiv_source_rank[subdiv];
227   if (my_rank == source_rank) return -1;
228   if (source_rank == 0) {
229     return (my_rank - 1) / 2;
230   } else {
231     int predecessor_rank = (my_rank / 2) - 1;
232     return (predecessor_rank < 0) ? source_rank : predecessor_rank;
233   }
234 }
235 
236 /* static */
TreeSendTo(const CollectiveParams & cp,int subdiv,std::vector<int> * targets)237 void HierarchicalTreeBroadcaster::TreeSendTo(const CollectiveParams& cp,
238                                              int subdiv,
239                                              std::vector<int>* targets) {
240   DCHECK_LT(subdiv, static_cast<int>(cp.subdiv_rank.size()));
241   int my_rank = cp.subdiv_rank[subdiv];
242   if (-1 == my_rank) return;
243 
244   const auto& impl = cp.instance.impl_details;
245   DCHECK_LT(subdiv, static_cast<int>(impl.subdiv_source_rank.size()));
246   int source_rank = impl.subdiv_source_rank[subdiv];
247 
248   int group_size = 0;
249   for (int i = 0; i < impl.subdiv_permutations[subdiv].size(); i++) {
250     if (impl.subdiv_permutations[subdiv][i] >= 0) {
251       group_size++;
252     }
253   }
254 
255   targets->clear();
256   int successor_rank = 0;
257   if (source_rank == 0) {
258     successor_rank = (2 * my_rank) + 1;
259   } else {
260     successor_rank = (2 * (my_rank + 1));
261   }
262   DCHECK_NE(successor_rank, my_rank);
263   if (cp.is_source && source_rank != 0) {
264     // The source sends to rank 0,1 in addition to its positional
265     // descendants.
266     if (group_size > 1) {
267       targets->push_back(0);
268     }
269     if (group_size > 2 && source_rank != 1) {
270       targets->push_back(1);
271     }
272   }
273   for (int i = 0; i < 2; ++i) {
274     if (successor_rank < group_size && successor_rank != source_rank) {
275       targets->push_back(successor_rank);
276     }
277     ++successor_rank;
278   }
279 }
280 
281 // Executes a hierarchical tree broadcast.
282 // Each subdiv is a broadcast between a subset of the devices.
283 // If there is only one task, there is one subdiv comprising a broadcast between
284 // all devices belonging to the task.
285 // If there are n tasks, n>1, then there are n+1 subdivs.  In the first (global)
286 // subdiv, one device from each task participates in a binary tree broadcast.
287 // Each task receives a copy of the tensor on one device via this broadcast.
288 // Subsequent subdivs correspond to intra-task broadcasts.  Subdiv i+1
289 // corresponds to broadcast between all devices on task i.  Thus, each task
290 // participates in at most 2 subdivs.
RunTree()291 void HierarchicalTreeBroadcaster::RunTree() {
292   int num_subdivs = static_cast<int>(col_params_->subdiv_rank.size());
293   // TODO(b/78352018): this is easily improved when a node participates in both
294   // first and second subdivision.  It would first send to its descendents in
295   // the first subdiv, then wait until all pending ops are finished before
296   // sending to descendents in second subdiv.  A better implementation would
297   // collapse the two send blocks.
298   for (int si = 0; si < num_subdivs; si++) {
299     int my_rank = col_params_->subdiv_rank[si];
300     // If rank is -1, this device does not participate in this subdiv.
301     if (-1 == my_rank) continue;
302     int source_rank = col_params_->instance.impl_details.subdiv_source_rank[si];
303     if (VLOG_IS_ON(1)) {
304       string subdiv_buf;
305       for (int r : col_params_->instance.impl_details.subdiv_permutations[si]) {
306         strings::StrAppend(&subdiv_buf, r, ",");
307       }
308       VLOG(1) << "Running Broadcast tree device=" << col_ctx_->device_name
309               << " subdiv=" << si << " perm=" << subdiv_buf
310               << " my_rank=" << my_rank << " source_rank=" << source_rank;
311     }
312 
313     mutex mu;               // also guards status_ while callbacks are pending
314     int pending_count = 0;  // TF_GUARDED_BY(mu)
315     condition_variable all_done;
316 
317     if (my_rank >= 0 && my_rank != source_rank) {
318       // Begin by receiving the value.
319       profiler::TraceMe activity(
320           [&] { return strings::StrCat("ReceiveValue:", si); },
321           profiler::TraceMeLevel::kInfo);
322       int recv_from_rank = TreeRecvFrom(*col_params_, si);
323       Notification note;
324       DispatchRecv(si, recv_from_rank, my_rank, col_ctx_->output,
325                    [this, &mu, &note](const Status& s) {
326                      mutex_lock l(mu);
327                      status_.Update(s);
328                      note.Notify();
329                    });
330       note.WaitForNotification();
331     }
332 
333     // Then forward value to all descendent devices.
334     {
335       profiler::TraceMe activity(
336           [&] { return strings::StrCat("ForwardValue:", si); },
337           profiler::TraceMeLevel::kInfo);
338       if (my_rank >= 0 && status_.ok()) {
339         std::vector<int> send_to_ranks;
340         TreeSendTo(*col_params_, si, &send_to_ranks);
341         for (int i = 0; i < send_to_ranks.size(); ++i) {
342           int target_rank = send_to_ranks[i];
343           {
344             mutex_lock l(mu);
345             ++pending_count;
346           }
347           DispatchSend(si, target_rank, my_rank,
348                        (is_source_ ? col_ctx_->input : col_ctx_->output),
349                        [this, &mu, &pending_count, &all_done](const Status& s) {
350                          mutex_lock l(mu);
351                          status_.Update(s);
352                          --pending_count;
353                          if (pending_count == 0) {
354                            all_done.notify_all();
355                          }
356                        });
357         }
358       }
359 
360       // For the original source device, we copy input to output if they are
361       // different.
362       // If there is only 1 subdiv, we do this in that subdiv.  If there is more
363       // than 1 subdiv, then the original source device will participate in 2
364       // subdivs - the global inter-task broadcast and one local intra-task
365       // broadcast.  In this case, we perform the copy in the second subdiv for
366       // this device.
367       if (status_.ok() && is_source_ && (1 == num_subdivs || 0 != si)) {
368         VLOG(2) << "copying input to output for device="
369                 << col_ctx_->device_name << " subdiv=" << si;
370         if (col_ctx_->input != col_ctx_->output &&
371             (DMAHelper::base(col_ctx_->input) !=
372              DMAHelper::base(col_ctx_->output))) {
373           {
374             mutex_lock l(mu);
375             ++pending_count;
376           }
377           DeviceContext* op_dev_ctx = col_ctx_->op_ctx->op_device_context();
378           CollectiveRemoteAccessLocal::MemCpyAsync(
379               op_dev_ctx, op_dev_ctx, col_ctx_->device, col_ctx_->device,
380               col_ctx_->op_ctx->input_alloc_attr(0),
381               col_ctx_->op_ctx->output_alloc_attr(0), col_ctx_->input,
382               col_ctx_->output, 0, /*stream_index*/
383               [this, &mu, &pending_count, &all_done](const Status& s) {
384                 mutex_lock l(mu);
385                 status_.Update(s);
386                 --pending_count;
387                 if (0 == pending_count) {
388                   all_done.notify_all();
389                 }
390               });
391         }
392       }
393 
394       // Then wait for all pending actions to complete.
395       {
396         mutex_lock l(mu);
397         if (pending_count > 0) {
398           all_done.wait(l);
399         }
400       }
401     }
402   }
403   VLOG(2) << "device=" << col_ctx_->device_name << " return status " << status_;
404   done_(status_);
405 }
406 
DispatchSend(int subdiv,int dst_rank,int src_rank,const Tensor * src_tensor,const StatusCallback & done)407 void HierarchicalTreeBroadcaster::DispatchSend(int subdiv, int dst_rank,
408                                                int src_rank,
409                                                const Tensor* src_tensor,
410                                                const StatusCallback& done) {
411   profiler::ScopedMemoryDebugAnnotation op_annotation(
412       col_params_->name.data(), col_ctx_->step_id, "dynamic",
413       src_tensor->dtype(),
414       [src_tensor]() { return src_tensor->shape().DebugString(); });
415   string send_buf_key =
416       BroadcastBufKey(col_ctx_->exec_key, subdiv, src_rank, dst_rank);
417   int dst_idx =
418       col_params_->instance.impl_details.subdiv_permutations[subdiv][dst_rank];
419   VLOG(3) << "DispatchSend " << send_buf_key << " from_device "
420           << col_ctx_->device_name << " to_device "
421           << col_params_->group.members[dst_idx].device.name()
422           << " subdiv=" << subdiv << " dst_rank=" << dst_rank
423           << " dst_idx=" << dst_idx;
424   col_ctx_->col_exec->remote_access()->PostToPeer(
425       col_params_->group.members[dst_idx].device.name(),
426       col_params_->group.members[dst_idx].task, send_buf_key, col_ctx_->device,
427       col_ctx_->op_ctx->op_device_context(),
428       col_ctx_->op_ctx->output_alloc_attr(0), src_tensor,
429       col_ctx_->device_locality, col_ctx_->op_ctx->cancellation_manager(),
430       done);
431 }
432 
DispatchRecv(int subdiv,int src_rank,int dst_rank,Tensor * dst_tensor,const StatusCallback & done)433 void HierarchicalTreeBroadcaster::DispatchRecv(int subdiv, int src_rank,
434                                                int dst_rank, Tensor* dst_tensor,
435                                                const StatusCallback& done) {
436   string recv_buf_key =
437       BroadcastBufKey(col_ctx_->exec_key, subdiv, src_rank, dst_rank);
438   int src_idx =
439       col_params_->instance.impl_details.subdiv_permutations[subdiv][src_rank];
440   VLOG(3) << "DispatchRecv " << recv_buf_key << " from_device "
441           << col_params_->group.members[src_idx].device.name() << " to_device "
442           << col_ctx_->device_name << " subdiv=" << subdiv
443           << " src_rank=" << src_rank << " src_idx=" << src_idx;
444   col_ctx_->col_exec->remote_access()->RecvFromPeer(
445       col_params_->group.members[src_idx].device.name(),
446       col_params_->group.members[src_idx].task,
447       col_params_->group.members[src_idx].is_local, recv_buf_key,
448       col_ctx_->device, col_ctx_->op_ctx->op_device_context(),
449       col_ctx_->op_ctx->output_alloc_attr(0), dst_tensor,
450       col_ctx_->device_locality, 0 /*stream_index*/,
451       col_ctx_->op_ctx->cancellation_manager(), done);
452 }
453 
454 namespace {
455 REGISTER_COLLECTIVE(HierarchicalTreeBroadcast, HierarchicalTreeBroadcaster);
456 }  // namespace
457 
458 }  // namespace tensorflow
459