• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h"
16 
17 #include <functional>
18 #include <memory>
19 #include <string>
20 #include <utility>
21 
22 #include "tensorflow/core/common_runtime/collective_rma_local.h"
23 #include "tensorflow/core/common_runtime/collective_util.h"
24 #include "tensorflow/core/common_runtime/device_mgr.h"
25 #include "tensorflow/core/common_runtime/dma_helper.h"
26 #include "tensorflow/core/framework/device_base.h"
27 #include "tensorflow/core/framework/op_kernel.h"
28 #include "tensorflow/core/framework/tensor.h"
29 #include "tensorflow/core/lib/core/notification.h"
30 #include "tensorflow/core/lib/core/status.h"
31 #include "tensorflow/core/lib/strings/str_util.h"
32 #include "tensorflow/core/lib/strings/strcat.h"
33 #include "tensorflow/core/platform/env.h"
34 #include "tensorflow/core/platform/types.h"
35 
36 // Set to true to make BroadcastBufKey() emit verbose, self-describing
// BufRendezvous keys of the form
// "broadcast(<exec_key>):subdiv(<s>):src(<r>):dst(<r>)", which makes debug-mode
// log messages easier to follow.  When false (the default here), the compact
// "<exec_key>:<s>:<r>:<r>" form is used instead.
37 #define READABLE_KEYS false
38 
39 namespace tensorflow {
40 
41 namespace {
42 // Key to be used for BufRendezvous by Broadcaster.
BroadcastBufKey(const string & exec_key,int subdiv,int src_rank,int dst_rank)43 string BroadcastBufKey(const string& exec_key, int subdiv, int src_rank,
44                        int dst_rank) {
45   if (READABLE_KEYS) {
46     return strings::StrCat("broadcast(", exec_key, "):subdiv(", subdiv,
47                            "):src(", src_rank, "):dst(", dst_rank, ")");
48   } else {
49     // TODO(b/78352018): Try a denser format, e.g. a 64 or 128 bit hash.
50     return strings::StrCat(exec_key, ":", subdiv, ":", src_rank, ":", dst_rank);
51   }
52 }
53 }  // namespace
54 
HierarchicalTreeBroadcaster()55 HierarchicalTreeBroadcaster::HierarchicalTreeBroadcaster()
56     : col_ctx_(nullptr),
57       col_params_(nullptr),
58       done_(nullptr),
59       is_source_(false) {}
60 
GetDeviceTask(int device_rank,const std::vector<int> & dev_per_task)61 int HierarchicalTreeBroadcaster::GetDeviceTask(
62     int device_rank, const std::vector<int>& dev_per_task) {
63   int num_tasks = static_cast<int>(dev_per_task.size());
64   int task_lo = 0;
65   int task_hi = -1;
66   for (int ti = 0; ti < num_tasks; ti++) {
67     task_hi = task_lo + dev_per_task[ti];
68     if (task_lo <= device_rank && device_rank < task_hi) return ti;
69     task_lo = task_hi;
70   }
71   LOG(FATAL) << "Unexpected device rank " << device_rank << " for " << task_hi
72              << " devices";
73   return -1;
74 }
75 
InitializeCollectiveParams(CollectiveParams * col_params)76 Status HierarchicalTreeBroadcaster::InitializeCollectiveParams(
77     CollectiveParams* col_params) {
78   CHECK_EQ(col_params->instance.type, BROADCAST_COLLECTIVE);
79   CHECK_EQ(col_params->instance.impl_details.collective_name,
80            "HierarchicalTreeBroadcast");
81   const string& device_name =
82       col_params->instance.device_names[col_params->default_rank];
83   // Start by counting the devices in each task.
84   // Precondition: device_names must be sorted so that all devices in
85   // the same task are adjacent.
86   VLOG(2) << "Sorted task names: "
87           << str_util::Join(col_params->instance.task_names, ", ");
88   std::vector<int> dev_per_task;
89   const string* prior_task_name = &col_params->instance.task_names[0];
90   int dev_count = 1;
91   for (int di = 1; di < col_params->group.group_size; ++di) {
92     if (col_params->instance.task_names[di] != *prior_task_name) {
93       dev_per_task.push_back(dev_count);
94       dev_count = 1;
95       prior_task_name = &col_params->instance.task_names[di];
96     } else {
97       ++dev_count;
98     }
99   }
100   dev_per_task.push_back(dev_count);
101   CHECK_EQ(col_params->group.num_tasks, dev_per_task.size());
102 
103   if (VLOG_IS_ON(2)) {
104     string dpt_buf;
105     for (int dpt : dev_per_task) strings::StrAppend(&dpt_buf, dpt, ";");
106     VLOG(2) << "HierarchicalTreeBroadcaster::InitializeCollectiveParams device="
107             << device_name << " source_rank=" << col_params->source_rank
108             << " dev_per_task=" << dpt_buf;
109   }
110   int num_tasks = col_params->group.num_tasks;
111   // If there is just 1 task, then execute binary tree broadcast over all
112   // devices.  Otherwise, the first subdiv is inter-task broadcast, and then
113   // there are N more subdivs, where N is #task.
114   int num_subdivs = num_tasks + (num_tasks > 1 ? 1 : 0);
115   int total_num_devices = 0;
116   for (int num_dev : dev_per_task) total_num_devices += num_dev;
117 
118   col_params->instance.impl_details.subdiv_permutations.resize(num_subdivs);
119   col_params->subdiv_rank.reserve(num_subdivs);
120   col_params->instance.impl_details.subdiv_source_rank.reserve(num_subdivs);
121 
122   // Inter-task subdiv.  Pick one device from each task - this is the source
123   // device if it belongs to that task, or device 0 for that task.  If a device
124   // does not participate in the subdiv, set subdiv_rank to -1.
125   if (num_tasks > 1) {
126     const int sdi = 0;
127     std::vector<int>& perm =
128         col_params->instance.impl_details.subdiv_permutations[sdi];
129     CHECK_EQ(perm.size(), 0);
130     int device_count = 0;
131     int source_task = GetDeviceTask(col_params->source_rank, dev_per_task);
132     for (int ti = 0; ti < col_params->group.num_tasks; ti++) {
133       bool participate = false;
134       if (source_task == ti) {
135         // Source device belongs to this task.
136         perm.push_back(col_params->source_rank);
137         participate =
138             col_params->instance.device_names[col_params->source_rank] ==
139             device_name;
140       } else {
141         // Source does not belong to this task, choose dev 0.
142         perm.push_back(device_count);
143         participate =
144             col_params->instance.device_names[device_count] == device_name;
145       }
146       if (participate) col_params->subdiv_rank.push_back(ti);
147       device_count += dev_per_task[ti];
148     }
149     if (col_params->subdiv_rank.empty()) col_params->subdiv_rank.push_back(-1);
150     col_params->instance.impl_details.subdiv_source_rank.push_back(source_task);
151   }
152 
153   // Intra-task subdivs.  Pick all devices in task ti for subdiv sdi.  Set
154   // source to dev 0 for that task if it does not contain original source, else
155   // set to rank of original source.  If a device does not participate in
156   // the subdiv, set subdiv_rank to -1;
157   int abs_di = 0;
158   for (int ti = 0; ti < col_params->group.num_tasks; ti++) {
159     const int sdi = ti + (num_tasks > 1 ? 1 : 0);
160     std::vector<int>& perm =
161         col_params->instance.impl_details.subdiv_permutations[sdi];
162     CHECK_EQ(perm.size(), 0);
163     bool participate = false;
164     int subdiv_source = 0;
165     for (int di = 0; di < dev_per_task[ti]; di++) {
166       perm.push_back(abs_di);
167       if (col_params->instance.device_names[abs_di] == device_name) {
168         participate = true;
169         col_params->subdiv_rank.push_back(di);
170       }
171       if (abs_di == col_params->source_rank) subdiv_source = di;
172       abs_di++;
173     }
174     if (!participate) col_params->subdiv_rank.push_back(-1);
175     col_params->instance.impl_details.subdiv_source_rank.push_back(
176         subdiv_source);
177   }
178 
179   for (int sri = 0; sri < num_subdivs; sri++) {
180     CHECK_GE(col_params->instance.impl_details.subdiv_source_rank[sri], 0);
181   }
182 
183   VLOG(2) << collective_util::SubdivPermDebugString(*col_params);
184   return Status::OK();
185 }
186 
InitializeCollectiveContext(CollectiveContext * col_ctx)187 Status HierarchicalTreeBroadcaster::InitializeCollectiveContext(
188     CollectiveContext* col_ctx) {
189   CHECK(col_ctx->dev_mgr);
190   col_ctx_ = col_ctx;
191   col_params_ = &col_ctx->col_params;
192   return collective_util::InitializeDeviceAndLocality(
193       col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device,
194       &col_ctx->device_locality);
195 }
196 
Run(StatusCallback done)197 void HierarchicalTreeBroadcaster::Run(StatusCallback done) {
198   CHECK(col_ctx_);
199   CHECK(col_params_);
200   done_ = std::move(done);
201   is_source_ = col_params_->is_source;
202   RunTree();
203 }
204 
205 // Binary tree parent/child relations are trivial to calculate, i.e.
206 // device at rank r is the parent of 2r+1 and 2r+2.  The one exception
207 // is if the source is not rank 0.  We treat that case as though the
208 // source is appended to the front of the rank ordering as well as
209 // continuing to occupy its current position.  Hence we calculate as
210 // though each device's rank is actually r+1, then subtract 1 again to
211 // get the descendent ranks.  If the source is not rank 0 then its
212 // descendants include both {0,1} and the descendents of its current
213 // position.  Where a non-0-rank source is a descendent of another
214 // device, no send to it is necessary.
215 
216 /* static*/
TreeRecvFrom(const CollectiveParams & cp,int subdiv)217 int HierarchicalTreeBroadcaster::TreeRecvFrom(const CollectiveParams& cp,
218                                               int subdiv) {
219   DCHECK_LT(subdiv, static_cast<int>(cp.subdiv_rank.size()));
220   int my_rank = cp.subdiv_rank[subdiv];
221   if (-1 == my_rank) return -1;
222 
223   const auto& impl = cp.instance.impl_details;
224   DCHECK_LT(subdiv, static_cast<int>(impl.subdiv_source_rank.size()));
225   int source_rank = impl.subdiv_source_rank[subdiv];
226   if (my_rank == source_rank) return -1;
227   if (source_rank == 0) {
228     return (my_rank - 1) / 2;
229   } else {
230     int predecessor_rank = (my_rank / 2) - 1;
231     return (predecessor_rank < 0) ? source_rank : predecessor_rank;
232   }
233 }
234 
235 /* static */
TreeSendTo(const CollectiveParams & cp,int subdiv,std::vector<int> * targets)236 void HierarchicalTreeBroadcaster::TreeSendTo(const CollectiveParams& cp,
237                                              int subdiv,
238                                              std::vector<int>* targets) {
239   DCHECK_LT(subdiv, static_cast<int>(cp.subdiv_rank.size()));
240   int my_rank = cp.subdiv_rank[subdiv];
241   if (-1 == my_rank) return;
242 
243   const auto& impl = cp.instance.impl_details;
244   DCHECK_LT(subdiv, static_cast<int>(impl.subdiv_source_rank.size()));
245   int source_rank = impl.subdiv_source_rank[subdiv];
246 
247   int group_size = 0;
248   for (int i = 0; i < impl.subdiv_permutations[subdiv].size(); i++) {
249     if (impl.subdiv_permutations[subdiv][i] >= 0) {
250       group_size++;
251     }
252   }
253 
254   targets->clear();
255   int successor_rank = 0;
256   if (source_rank == 0) {
257     successor_rank = (2 * my_rank) + 1;
258   } else {
259     successor_rank = (2 * (my_rank + 1));
260   }
261   DCHECK_NE(successor_rank, my_rank);
262   if (cp.is_source && source_rank != 0) {
263     // The source sends to rank 0,1 in addition to its positional
264     // descendants.
265     if (group_size > 1) {
266       targets->push_back(0);
267     }
268     if (group_size > 2 && source_rank != 1) {
269       targets->push_back(1);
270     }
271   }
272   for (int i = 0; i < 2; ++i) {
273     if (successor_rank < group_size && successor_rank != source_rank) {
274       targets->push_back(successor_rank);
275     }
276     ++successor_rank;
277   }
278 }
279 
280 // Executes a hierarchical tree broadcast.
281 // Each subdiv is a broadcast between a subset of the devices.
282 // If there is only one task, there is one subdiv comprising a broadcast between
283 // all devices belonging to the task.
284 // If there are n tasks, n>1, then there are n+1 subdivs.  In the first (global)
285 // subdiv, one device from each task participates in a binary tree broadcast.
286 // Each task receives a copy of the tensor on one device via this broadcast.
287 // Subsequent subdivs correspond to intra-task broadcasts.  Subdiv i+1
288 // corresponds to broadcast between all devices on task i.  Thus, each task
289 // participates in at most 2 subdivs.
RunTree()290 void HierarchicalTreeBroadcaster::RunTree() {
291   int num_subdivs = static_cast<int>(col_params_->subdiv_rank.size());
292   // TODO(b/78352018): this is easily improved when a node participates in both
293   // first and second subdivision.  It would first send to its descendents in
294   // the first subdiv, then wait until all pending ops are finished before
295   // sending to descendents in second subdiv.  A better implementation would
296   // collapse the two send blocks.
297   for (int si = 0; si < num_subdivs; si++) {
298     int my_rank = col_params_->subdiv_rank[si];
299     // If rank is -1, this device does not participate in this subdiv.
300     if (-1 == my_rank) continue;
301     int source_rank = col_params_->instance.impl_details.subdiv_source_rank[si];
302     if (VLOG_IS_ON(1)) {
303       string subdiv_buf;
304       for (int r : col_params_->instance.impl_details.subdiv_permutations[si]) {
305         strings::StrAppend(&subdiv_buf, r, ",");
306       }
307       VLOG(1) << "Running Broadcast tree device=" << col_ctx_->device_name
308               << " subdiv=" << si << " perm=" << subdiv_buf
309               << " my_rank=" << my_rank << " source_rank=" << source_rank;
310     }
311 
312     mutex mu;               // also guards status_ while callbacks are pending
313     int pending_count = 0;  // GUARDED_BY(mu)
314     condition_variable all_done;
315 
316     if (my_rank >= 0 && my_rank != source_rank) {
317       // Begin by receiving the value.
318       int recv_from_rank = TreeRecvFrom(*col_params_, si);
319       Notification note;
320       DispatchRecv(si, recv_from_rank, my_rank, col_ctx_->output,
321                    [this, &mu, &note](const Status& s) {
322                      mutex_lock l(mu);
323                      status_.Update(s);
324                      note.Notify();
325                    });
326       note.WaitForNotification();
327     }
328 
329     // Then forward value to all descendent devices.
330     if (my_rank >= 0 && status_.ok()) {
331       std::vector<int> send_to_ranks;
332       TreeSendTo(*col_params_, si, &send_to_ranks);
333       for (int i = 0; i < send_to_ranks.size(); ++i) {
334         int target_rank = send_to_ranks[i];
335         {
336           mutex_lock l(mu);
337           ++pending_count;
338         }
339         DispatchSend(si, target_rank, my_rank,
340                      (is_source_ ? col_ctx_->input : col_ctx_->output),
341                      [this, &mu, &pending_count, &all_done](const Status& s) {
342                        mutex_lock l(mu);
343                        status_.Update(s);
344                        --pending_count;
345                        if (pending_count == 0) {
346                          all_done.notify_all();
347                        }
348                      });
349       }
350     }
351 
352     // For the original source device, we copy input to output if they are
353     // different.
354     // If there is only 1 subdiv, we do this in that subdiv.  If there is more
355     // than 1 subdiv, then the original source device will participate in 2
356     // subdivs - the global inter-task broadcast and one local intra-task
357     // broadcast.  In this case, we perform the copy in the second subdiv for
358     // this device.
359     if (status_.ok() && is_source_ && (1 == num_subdivs || 0 != si)) {
360       VLOG(2) << "copying input to output for device=" << col_ctx_->device_name
361               << " subdiv=" << si;
362       if (col_ctx_->input != col_ctx_->output &&
363           (DMAHelper::base(col_ctx_->input) !=
364            DMAHelper::base(col_ctx_->output))) {
365         {
366           mutex_lock l(mu);
367           ++pending_count;
368         }
369         DeviceContext* op_dev_ctx = col_ctx_->op_ctx->op_device_context();
370         CollectiveRemoteAccessLocal::MemCpyAsync(
371             op_dev_ctx, op_dev_ctx, col_ctx_->device, col_ctx_->device,
372             col_ctx_->op_ctx->input_alloc_attr(0),
373             col_ctx_->op_ctx->output_alloc_attr(0), col_ctx_->input,
374             col_ctx_->output, 0, /*stream_index*/
375             [this, &mu, &pending_count, &all_done](const Status& s) {
376               mutex_lock l(mu);
377               status_.Update(s);
378               --pending_count;
379               if (0 == pending_count) {
380                 all_done.notify_all();
381               }
382             });
383       }
384     }
385 
386     // Then wait for all pending actions to complete.
387     {
388       mutex_lock l(mu);
389       if (pending_count > 0) {
390         all_done.wait(l);
391       }
392     }
393   }
394   VLOG(2) << "device=" << col_ctx_->device_name << " return status " << status_;
395   done_(status_);
396 }
397 
DispatchSend(int subdiv,int dst_rank,int src_rank,const Tensor * src_tensor,const StatusCallback & done)398 void HierarchicalTreeBroadcaster::DispatchSend(int subdiv, int dst_rank,
399                                                int src_rank,
400                                                const Tensor* src_tensor,
401                                                const StatusCallback& done) {
402   string send_buf_key =
403       BroadcastBufKey(col_ctx_->exec_key, subdiv, src_rank, dst_rank);
404   int dst_idx =
405       col_params_->instance.impl_details.subdiv_permutations[subdiv][dst_rank];
406   VLOG(3) << "DispatchSend " << send_buf_key << " from_device "
407           << col_ctx_->device_name << " to_device "
408           << col_params_->instance.device_names[dst_idx] << " subdiv=" << subdiv
409           << " dst_rank=" << dst_rank << " dst_idx=" << dst_idx;
410   col_ctx_->col_exec->PostToPeer(col_params_->instance.device_names[dst_idx],
411                                  col_params_->instance.task_names[dst_idx],
412                                  send_buf_key, col_ctx_->device,
413                                  col_ctx_->op_ctx->op_device_context(),
414                                  col_ctx_->op_ctx->output_alloc_attr(0),
415                                  src_tensor, col_ctx_->device_locality, done);
416 }
417 
DispatchRecv(int subdiv,int src_rank,int dst_rank,Tensor * dst_tensor,const StatusCallback & done)418 void HierarchicalTreeBroadcaster::DispatchRecv(int subdiv, int src_rank,
419                                                int dst_rank, Tensor* dst_tensor,
420                                                const StatusCallback& done) {
421   string recv_buf_key =
422       BroadcastBufKey(col_ctx_->exec_key, subdiv, src_rank, dst_rank);
423   int src_idx =
424       col_params_->instance.impl_details.subdiv_permutations[subdiv][src_rank];
425   VLOG(3) << "DispatchRecv " << recv_buf_key << " from_device "
426           << col_params_->instance.device_names[src_idx] << " to_device "
427           << col_ctx_->device_name << " subdiv=" << subdiv
428           << " src_rank=" << src_rank << " src_idx=" << src_idx;
429   col_ctx_->col_exec->RecvFromPeer(
430       col_params_->instance.device_names[src_idx],
431       col_params_->instance.task_names[src_idx],
432       col_params_->task.is_local[src_idx], recv_buf_key, col_ctx_->device,
433       col_ctx_->op_ctx->op_device_context(),
434       col_ctx_->op_ctx->output_alloc_attr(0), dst_tensor,
435       col_ctx_->device_locality, 0 /*stream_index*/, done);
436 }
437 
438 REGISTER_COLLECTIVE(HierarchicalTreeBroadcast, HierarchicalTreeBroadcaster);
439 
440 }  // namespace tensorflow
441