1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifdef TENSORFLOW_USE_MPI
17 
18 #include <queue>
19 #include <thread>
20 #include <unordered_map>
21 
22 #include "tensorflow/core/framework/op.h"
23 #include "tensorflow/core/framework/op_kernel.h"
24 #include "tensorflow/core/framework/types.pb.h"
25 #include "tensorflow/core/platform/mutex.h"
26 
27 #define EIGEN_USE_THREADS
28 
29 #if GOOGLE_CUDA
30 #include <cuda_runtime.h>
31 #include "tensorflow/stream_executor/stream.h"
32 #endif
33 
34 #include "tensorflow/stream_executor/lib/statusor.h"
35 
36 #define OMPI_SKIP_MPICXX
37 #include "third_party/mpi/mpi.h"
38 #include "tensorflow/contrib/mpi_collectives/kernels/ring.h"
39 #include "tensorflow/contrib/mpi_collectives/mpi_message.pb.h"
40 
41 /*
42  * MPI Allreduce and Allgather Ops for TensorFlow.
43  *
44  * TensorFlow natively provides inter-device communication through send and
45  * receive ops and inter-node communication through Distributed TensorFlow,
46  * based on the same send and receive abstractions. These end up being
47  * insufficient for synchronous data-parallel training on HPC clusters where
48  * InfiniBand or other high-speed interconnects are available. This module
49  * implements MPI ops for allgather and allreduce, which do bandwidth-optimal
50  * gathers and reductions and can take advantage of hardware-optimized
51  * communication libraries through the MPI implementation.
52  *
53  * The primary logic of the allreduce and allgather is in RingAllreduce() and
54  * RingAllgather(). The background thread that facilitates MPI operations is
55  * run in BackgroundThreadLoop(). The provided MPI ops are:
56  *      – MPIInit:
57  *          Initialize MPI on a given device (CPU or GPU).
58  *          Should only be run on a single device in every process.
59  *      – MPISize:
60  *          Get the number of MPI processes in the global communicator.
61  *      – MPIRank:
62  *          Get the rank of the current MPI process in the global communicator.
63  *      – MPILocalRank:
64  *          Get the local rank of the current MPI process within its node.
65  *      – MPIAllreduce:
66  *          Perform an allreduce on a Tensor, returning the sum
67  *          across all MPI processes in the global communicator.
68  *      – MPIAllgather:
69  *          Perform an allgather on a Tensor, returning the concatenation of
70  *          the tensor on the first dimension across all MPI processes in the
71  *          global communicator.
72  *
73  */
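
/*
 * The following is a small, self-contained sketch (added for illustration; it
 * is not part of this kernel and is not the RingAllreduce()/RingAllgather()
 * implementation in ring.h) of the ring allreduce idea referenced above. It
 * simulates k ranks in one process with plain vectors: a reduce-scatter phase
 * followed by an allgather phase leaves every simulated rank holding the
 * elementwise sum, while each rank only ever exchanges one segment per step
 * with its ring neighbor.
 *
 *   #include <cstdio>
 *   #include <vector>
 *
 *   int main() {
 *     const int k = 4;   // simulated ranks
 *     const int n = 10;  // elements per rank
 *     // Rank r starts with every element equal to r + 1, so the reduced
 *     // result should be 1 + 2 + ... + k = k * (k + 1) / 2 everywhere.
 *     std::vector<std::vector<float> > data(k, std::vector<float>(n));
 *     for (int r = 0; r < k; r++)
 *       for (int i = 0; i < n; i++) data[r][i] = r + 1;
 *
 *     // Segment s covers indices [s * n / k, (s + 1) * n / k).
 *     // Reduce-scatter: at step t, rank r forwards its partial sum of
 *     // segment (r - t) mod k to rank (r + 1) mod k, which accumulates it.
 *     // After k - 1 steps, rank r owns the full sum of segment (r + 1) mod k.
 *     for (int t = 0; t < k - 1; t++) {
 *       for (int r = 0; r < k; r++) {
 *         int s = ((r - t) % k + k) % k;
 *         int dst = (r + 1) % k;
 *         for (int i = s * n / k; i < (s + 1) * n / k; i++)
 *           data[dst][i] += data[r][i];
 *       }
 *     }
 *     // Allgather: at step t, rank r forwards the completed segment
 *     // (r + 1 - t) mod k to rank (r + 1) mod k, which copies it.
 *     for (int t = 0; t < k - 1; t++) {
 *       for (int r = 0; r < k; r++) {
 *         int s = ((r + 1 - t) % k + k) % k;
 *         int dst = (r + 1) % k;
 *         for (int i = s * n / k; i < (s + 1) * n / k; i++)
 *           data[dst][i] = data[r][i];
 *       }
 *     }
 *     std::printf("rank 0, element 0: %g (expected %d)\n", data[0][0],
 *                 k * (k + 1) / 2);
 *     return 0;
 *   }
 */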
74 
75 template <class T>
76 using StatusOr = stream_executor::port::StatusOr<T>;
77 
78 using CPUDevice = Eigen::ThreadPoolDevice;
79 using GPUDevice = Eigen::GpuDevice;
80 
81 namespace tensorflow {
82 namespace contrib {
83 namespace mpi_collectives {
84 
85 // Make sure template specializations are generated in the ring.cu.cc and the
86 // ring.cc file, not in this file.
87 extern template Status RingAllreduce<GPUDevice, int>(OpKernelContext*,
88                                                      const Tensor*, Tensor*,
89                                                      Tensor*);
90 extern template Status RingAllreduce<GPUDevice, long long>(OpKernelContext*,
91                                                            const Tensor*,
92                                                            Tensor*, Tensor*);
93 extern template Status RingAllreduce<GPUDevice, float>(OpKernelContext*,
94                                                        const Tensor*, Tensor*,
95                                                        Tensor*);
96 extern template Status RingAllgather<GPUDevice, int>(OpKernelContext*,
97                                                      const Tensor*,
98                                                      const std::vector<size_t>&,
99                                                      Tensor*);
100 extern template Status RingAllgather<GPUDevice, long long>(
101     OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
102 extern template Status RingAllgather<GPUDevice, float>(
103     OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
104 extern template Status RingAllreduce<CPUDevice, int>(OpKernelContext*,
105                                                      const Tensor*, Tensor*,
106                                                      Tensor*);
107 extern template Status RingAllreduce<CPUDevice, long long>(OpKernelContext*,
108                                                            const Tensor*,
109                                                            Tensor*, Tensor*);
110 extern template Status RingAllreduce<CPUDevice, float>(OpKernelContext*,
111                                                        const Tensor*, Tensor*,
112                                                        Tensor*);
113 extern template Status RingAllgather<CPUDevice, int>(OpKernelContext*,
114                                                      const Tensor*,
115                                                      const std::vector<size_t>&,
116                                                      Tensor*);
117 extern template Status RingAllgather<CPUDevice, long long>(
118     OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
119 extern template Status RingAllgather<CPUDevice, float>(
120     OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
121 
122 namespace {
123 
124 // Return true if the templated type is GPUDevice, otherwise false.
125 template <typename T>
126 bool IsGPUDevice();
127 template <>
128 bool IsGPUDevice<GPUDevice>() {
129   return true;
130 };
131 template <>
132 bool IsGPUDevice<CPUDevice>() {
133   return false;
134 };
135 
136 // A callback to call after the MPI communication completes. Since the
137 // allreduce and allgather ops are asynchronous, this callback is what resumes
138 // computation after the reduction is completed.
139 typedef std::function<void(StatusOr<Tensor>)> CommunicationDoneCallback;
140 
141 struct CollectiveOpRecord {
142   // The rank performing this piece of the op
143   int rank;
144 
145   // The name of the op/tensor to be reduced
146   std::string name;
147 
148   // The op's kernel context
149   OpKernelContext* context;
150 
151   // Data type of the op
152   DataType dtype;
153 
154   // The input tensor
155   const Tensor* in_t;
156 
157   // Allgather: Vector of per-rank first-dimension sizes
158   std::vector<size_t> sizes_vec;
159 
160   // The temp tensor for intermediate results
161   Tensor temp_t;
162 
163   // The output tensor
164   Tensor* out_t;
165 
166   // Whether to run this op on the gpu
167   bool on_gpu;
168 
169   // The callback to call after the op has completed
170   CommunicationDoneCallback callback;
171 };
172 
173 // Table storing Tensors to be reduced, keyed by unique name.
174 // This table contains everything necessary to do the reduction
175 typedef std::unordered_map<std::string, CollectiveOpRecord> TensorTable;
176 
177 // Table for storing Tensor metadata on rank zero. This is used for error
178 // checking and size calculations, as well as determining when a reduction is
179 // ready to be done (when all nodes are ready to do it).
180 typedef std::unordered_map<std::string, std::vector<MPIRequest> > MessageTable;
181 
182 // The global state required for the MPI ops.
183 //
184 // MPI is a library that stores a lot of global per-program state and often
185 // requires running on a single thread. As a result, we have to have a single
186 // background thread responsible for all MPI operations, and communicate with
187 // that background thread through global state.
188 struct MPIGlobalState {
189   // An atomic boolean which is set to true when MPI is initialized.
190   // This ensures that MPI_Init is never called twice.
191   std::atomic_flag initialized_flag = ATOMIC_FLAG_INIT;
192 
193   // Condition variable to wait for initialization
194   condition_variable cv;
195 
196   // Whether MPI_Init has been completed on the background thread.
197   bool initialization_done = false;
198 
199   // Whether MPI_Init succeeded on the background thread.
200   Status init_status;
201 
202   // A mutex that needs to be used whenever MPI operations touch
203   // shared structures.
204   mutex mu;
205 
206   // Tensors waiting to be allreduced or allgathered.
207   TensorTable tensor_table;
208 
209   // Queue of MPI requests waiting to be sent to the coordinator node.
210   std::queue<MPIRequest> message_queue;
211 
212   // Background thread running MPI communication.
213   std::thread background_thread;
214 
215   // Whether the background thread should shut down.
216   bool shut_down = false;
217 
218   // Only exists on the coordinator node (rank zero). Maintains a count of
219   // how many nodes are ready to allreduce every tensor (keyed by tensor
220   // name).
221   std::unique_ptr<MessageTable> message_table;
222 
223   // The MPI rank, local rank, and size.
224   int rank = 0;
225   int local_rank = 0;
226   int size = 1;
227 
228   // The device that MPI was initialized on. (-1 for no GPU)
229   int device = -1;
230 
231   // The CUDA stream used for data transfers and within-allreduce operations.
232   // A naive implementation would use the TensorFlow StreamExecutor CUDA
233   // stream. However, the allreduce and allgather require doing memory copies
234   // and kernel executions (for accumulation of values on the GPU), and
235   // the subsequent operations must wait for those operations to complete,
236   // otherwise MPI (which uses its own stream internally) will begin the data
237   // transfers before the CUDA calls are complete. In order to wait for those
238   // CUDA operations, if we were using the TensorFlow stream, we would have
239   // to synchronize that stream; however, other TensorFlow threads may be
240   // submitting more work to that stream, so synchronizing on it can cause
241   // the allreduce to be delayed, waiting for compute totally unrelated to it
242   // in other parts of the graph. Overlapping memory transfers and compute
243   // during backpropagation is crucial for good performance, so we cannot use
244   // the TensorFlow stream, and must use our own stream.
245 #if GOOGLE_CUDA
246   cudaStream_t stream;
247   std::atomic_flag stream_created_flag = ATOMIC_FLAG_INIT;
248 #endif
249 
250   ~MPIGlobalState() {
251     // Make sure that the destructor of the background thread is safe to
252     // call. If a thread is still joinable (not yet joined or detached),
253     // destroying it would terminate the program, so join it first.
254     if (background_thread.joinable()) {
255       shut_down = true;
256       background_thread.join();
257     }
258   }
259 };
260 
261 // All the MPI state that must be stored globally per-process.
262 static MPIGlobalState mpi_global;
263 
264 // For clarity in argument lists.
265 #define RANK_ZERO 0
266 
267 // A tag used for all coordinator messaging.
268 #define TAG_NOTIFY 1
269 
270 // Store the MPIRequest for a name, and return whether the total count of
271 // MPIRequests for that tensor is now equal to the MPI size (and thus we are
272 // ready to reduce the tensor).
273 bool IncrementTensorCount(std::unique_ptr<MessageTable>& message_table,
274                           MPIRequest msg, int mpi_size) {
275   auto name = msg.tensor_name();
276   auto table_iter = message_table->find(name);
277   if (table_iter == message_table->end()) {
278     message_table->emplace(name, std::vector<MPIRequest>({msg}));
279     table_iter = message_table->find(name);
280   } else {
281     table_iter->second.push_back(msg);
282   }
283 
284   int count = table_iter->second.size();
285   return count == mpi_size;
286 }
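
// A minimal standalone sketch of the counting rule above, using plain STL
// types and hypothetical names (it drops the per-request records that the
// real MessageTable keeps for the error checking in ConstructMPIResponse):
//
//   #include <cstdio>
//   #include <string>
//   #include <unordered_map>
//
//   // Returns true once `name` has been announced by all mpi_size ranks.
//   bool IncrementCount(std::unordered_map<std::string, int>& table,
//                       const std::string& name, int mpi_size) {
//     return ++table[name] == mpi_size;
//   }
//
//   int main() {
//     std::unordered_map<std::string, int> table;
//     std::printf("%d\n", IncrementCount(table, "grad_w", 3));  // prints 0
//     std::printf("%d\n", IncrementCount(table, "grad_w", 3));  // prints 0
//     std::printf("%d\n", IncrementCount(table, "grad_w", 3));  // prints 1
//     return 0;
//   }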
287 
288 // Once a tensor is ready to be reduced, the coordinator sends an MPIResponse
289 // to all ranks instructing them to start the reduction. The MPIResponse
290 // also contains error messages in case the submitted MPIRequests were not
291 // valid (for example, contained mismatched shapes or types).
292 //
293 // Constructing the MPIResponse, thus, requires a whole lot of error checking.
294 MPIResponse ConstructMPIResponse(std::unique_ptr<MessageTable>& message_table,
295                                  std::string name) {
296   bool error = false;
297   auto it = message_table->find(name);
298   assert(it != message_table->end());
299 
300   std::vector<MPIRequest> requests = it->second;
301   assert(requests.size() > 0);
302 
303   std::ostringstream error_message_stream;
304 
305   // Check that all data types being reduced or gathered are identical
306   auto data_type = requests[0].tensor_type();
307   for (unsigned int i = 1; i < requests.size(); i++) {
308     auto request_type = requests[i].tensor_type();
309     if (data_type != request_type) {
310       error = true;
311       error_message_stream << "Mismatched data types: One rank had type "
312                            << DataType_Name(data_type)
313                            << ", but another rank had type "
314                            << DataType_Name(request_type) << ".";
315       break;
316     }
317   }
318 
319   // Check that all requested operations are the same
320   auto message_type = requests[0].request_type();
321   for (unsigned int i = 1; i < requests.size(); i++) {
322     if (error) {
323       break;
324     }
325 
326     auto request_type = requests[i].request_type();
327     if (message_type != request_type) {
328       error = true;
329       error_message_stream << "Mismatched MPI operations: One rank did an "
330                            << message_type << ", but another rank did an "
331                            << request_type << ".";
332       break;
333     }
334   }
335 
336   // If we are doing an allreduce, check that all tensor shapes
337   // are identical
338   if (message_type == MPIRequest::ALLREDUCE) {
339     TensorShape tensor_shape = requests[0].tensor_shape();
340     for (unsigned int i = 1; i < requests.size(); i++) {
341       if (error) {
342         break;
343       }
344 
345       TensorShape request_shape = requests[i].tensor_shape();
346       if (tensor_shape != request_shape) {
347         error = true;
348         error_message_stream << "Mismatched allreduce tensor shapes: "
349                              << "One rank reduced a tensor of shape "
350                              << tensor_shape.DebugString()
351                              << ", but another rank sent a tensor of shape "
352                              << request_shape.DebugString() << ".";
353         break;
354       }
355     }
356   }
357 
358   // If we are doing an allgather, make sure all but the first dimension are
359   // the same. The first dimension may differ across ranks; the output
360   // tensor's first dimension is their sum. Collect the sizes by rank.
361   if (message_type == MPIRequest::ALLGATHER) {
362     TensorShape tensor_shape = requests[0].tensor_shape();
363 
364     if (tensor_shape.dims() == 0) {
365       error = true;
366       error_message_stream << "Rank zero tried to gather a rank-zero tensor.";
367     }
368 
369     for (unsigned int i = 1; i < requests.size(); i++) {
370       if (error) {
371         break;
372       }
373 
374       TensorShape request_shape = requests[i].tensor_shape();
375       if (tensor_shape.dims() != request_shape.dims()) {
376         error = true;
377         error_message_stream << "Mismatched allgather tensor shapes: "
378                              << "One rank gathered a tensor of rank "
379                              << tensor_shape.dims()
380                              << ", but another rank sent a tensor of rank "
381                              << request_shape.dims() << ".";
382         break;
383       }
384 
385       for (unsigned int dim = 1; dim < tensor_shape.dims(); dim++) {
386         if (tensor_shape.dim_size(dim) != request_shape.dim_size(dim)) {
387           error = true;
388           error_message_stream
389               << "Mismatched allgather tensor shapes: "
390               << "One rank gathered a tensor with dimension " << dim
391               << " equal to " << tensor_shape.dim_size(dim)
392               << ", but another rank sent a tensor with dimension " << dim
393               << " equal to " << request_shape.dim_size(dim) << ".";
394           break;
395         }
396       }
397     }
398   }
399 
400   MPIResponse response;
401   response.set_tensor_name(name);
402   if (error) {
403     std::string error_message = error_message_stream.str();
404     response.set_response_type(MPIResponse::ERROR);
405     response.set_error_message(error_message);
406   } else {
407     auto response_type = MPIResponse::ERROR;
408     if (message_type == MPIRequest::ALLREDUCE) {
409       response_type = MPIResponse::ALLREDUCE;
410     } else {
411       response_type = MPIResponse::ALLGATHER;
412     }
413     response.set_response_type(response_type);
414   }
415 
416   // Clear all queued up requests for this name. They are now taken care of
417   // by the constructed MPI response.
418   message_table->erase(it);
419 
420   return response;
421 }
422 
423 // Process an MPIResponse by doing a reduction, a gather, or raising an error.
424 void PerformCollectiveOp(TensorTable& tensor_table, MPIResponse response) {
425   OpKernelContext* context;
426   const Tensor* input_tensor;
427   std::vector<size_t> sizes_vec;
428   Tensor temp_tensor;
429   Tensor* output_tensor;
430   CommunicationDoneCallback callback;
431   bool on_gpu;
432   {
433     // Lock on the tensor table.
434     mutex_lock guard(mpi_global.mu);
435 
436     // We should never fail at finding this key in the tensor table.
437     auto name = response.tensor_name();
438     auto iter = tensor_table.find(name);
439     assert(iter != tensor_table.end());
440 
441     assert(response.response_type() == MPIResponse::ALLREDUCE ||
442            response.response_type() == MPIResponse::ALLGATHER ||
443            response.response_type() == MPIResponse::ERROR);
444 
445     CollectiveOpRecord record = iter->second;
446     context = record.context;
447     input_tensor = record.in_t;
448     sizes_vec = record.sizes_vec;
449     temp_tensor = record.temp_t;
450     output_tensor = record.out_t;
451     on_gpu = record.on_gpu;
452     callback = record.callback;
453 
454     // Clear the tensor table of this tensor and its callbacks; the rest of
455     // this function takes care of it.
456     tensor_table.erase(iter);
457   }
458 
459   // Use CPUDevice instead of GPUDevice if no CUDA, to ensure we don't
460   // link to non-existent symbols.
461 #if GOOGLE_CUDA
462 #define GPU_DEVICE_IF_CUDA GPUDevice
463 #else
464 #define GPU_DEVICE_IF_CUDA CPUDevice
465 #endif
466 
467   Status status;
468   auto dtype = input_tensor->dtype();
469   if (response.response_type() == MPIResponse::ALLGATHER) {
470     if (dtype == DT_FLOAT) {
471       status = on_gpu ? RingAllgather<GPU_DEVICE_IF_CUDA, float>(
472                             context, input_tensor, sizes_vec, output_tensor)
473                       : RingAllgather<CPUDevice, float>(
474                             context, input_tensor, sizes_vec, output_tensor);
475     } else if (dtype == DT_INT32) {
476       status = on_gpu ? RingAllgather<GPU_DEVICE_IF_CUDA, int>(
477                             context, input_tensor, sizes_vec, output_tensor)
478                       : RingAllgather<CPUDevice, int>(context, input_tensor,
479                                                       sizes_vec, output_tensor);
480     } else if (dtype == DT_INT64) {
481       status = on_gpu ? RingAllgather<GPU_DEVICE_IF_CUDA, long long>(
482                             context, input_tensor, sizes_vec, output_tensor)
483                       : RingAllgather<CPUDevice, long long>(
484                             context, input_tensor, sizes_vec, output_tensor);
485     } else {
486       status = errors::Unknown("Invalid tensor type for MPI allgather.");
487     }
488   } else if (response.response_type() == MPIResponse::ALLREDUCE) {
489     if (dtype == DT_FLOAT) {
490       status = on_gpu ? RingAllreduce<GPU_DEVICE_IF_CUDA, float>(
491                             context, input_tensor, &temp_tensor, output_tensor)
492                       : RingAllreduce<CPUDevice, float>(
493                             context, input_tensor, &temp_tensor, output_tensor);
494     } else if (dtype == DT_INT32) {
495       status = on_gpu ? RingAllreduce<GPU_DEVICE_IF_CUDA, int>(
496                             context, input_tensor, &temp_tensor, output_tensor)
497                       : RingAllreduce<CPUDevice, int>(
498                             context, input_tensor, &temp_tensor, output_tensor);
499     } else if (dtype == DT_INT64) {
500       status = on_gpu ? RingAllreduce<GPU_DEVICE_IF_CUDA, long long>(
501                             context, input_tensor, &temp_tensor, output_tensor)
502                       : RingAllreduce<CPUDevice, long long>(
503                             context, input_tensor, &temp_tensor, output_tensor);
504     } else {
505       status = errors::Unknown("Invalid tensor type for MPI allreduce.");
506     }
507   } else if (response.response_type() == MPIResponse::ERROR) {
508     status = errors::FailedPrecondition(response.error_message());
509   }
510 
511   if (status.ok()) {
512     callback(StatusOr<Tensor>(*output_tensor));
513   } else {
514     callback(StatusOr<Tensor>(status));
515   }
516 }
517 
518 // The MPI background thread loop coordinates all the MPI processes and the
519 // tensor reductions. The design of the communicator mechanism is limited by a
520 // few considerations:
521 //
522 //      1. Some MPI implementations require all MPI calls to happen from a
523 //      single thread. Since TensorFlow may use several threads for graph
524 //      processing, this means we must have our own dedicated thread for
525 //      dealing with MPI.
526 //      2. We want to gracefully handle errors, when MPI processes do not
527 //      properly agree upon what should happen (such as mismatched types or
528 //      shapes). To do so requires the MPI processes to know about the shapes
529 //      and types of the relevant tensors on the other processes.
530 //      3. The MPI reductions and gathers should be able to happen in parallel
531 //      with other ongoing operations. Since MPI uses an internal
532 //      (inaccessible) GPU stream separate from the TF GPUDevice streams, we
533 //      cannot explicitly synchronize memcpys or kernels with it. As a result,
534 //      MPIAllreduce and MPIAllgather must be AsyncOpKernels to ensure proper
535 //      ordering of memcpys and kernels with respect to TF streams.
536 //      4. NOTE: We cannot guarantee that all the MPI processes reduce their
537 //      tensors in the same order. Thus, there must be a way to ensure the
538 //      reduction memcpys and kernels occur for correct tensors across all
539 //      ranks at the same time. We choose to use a coordinator (rank ID 0) to
540 //      gather and trigger the reduction operations that are ready to execute.
541 //
542 // The coordinator currently follows a master-worker paradigm. Rank zero acts
543 // as the master (the "coordinator"), whereas all other ranks are simply
544 // workers. Each rank runs its own background thread which progresses in ticks.
545 // In each tick, the following actions happen:
546 //
547 //      a) The workers send any available MPIRequests to the coordinator. These
548 //      MPIRequests indicate what the worker would like to do (i.e. which
549 //      tensor they would like to gather or reduce, as well as their shape and
550 //      type). They repeat this for every tensor that they would like to
551 //      operate on after that tensor's collective op has executed ComputeAsync.
552 //
553 //      b) The workers send an empty "DONE" message to the coordinator to
554 //      indicate that there are no more tensors they wish to operate on.
555 //
556 //      c) The coordinator receives the MPIRequests from the workers, as well
557 //      as from its own TensorFlow ops, and stores them in a request table. The
558 //      coordinator continues to receive MPIRequest messages until it has
559 //      received MPI_SIZE empty "DONE" messages.
560 //
561 //      d) The coordinator finds all tensors that are ready to be reduced or
562 //      gathered, as well as any operations that resulted in an error. For each,
563 //      it sends an MPIResponse to all the workers. When no more MPIResponses
564 //      are available, it sends a "DONE" response to the workers. If the
565 //      process is shutting down, it instead sends a "SHUTDOWN" response.
566 //
567 //      e) The workers listen for MPIResponse messages, processing each one by
568 //      doing the required reduce or gather, until they receive a "DONE"
569 //      response from the coordinator. At that point, the tick ends.
570 //      If instead of "DONE" they receive "SHUTDOWN", they exit their
571 //      background loop.
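//
// The zero-length "DONE" convention from steps b) and c) can be seen in
// isolation in the following standalone sketch (illustration only, not part
// of this kernel; it would be compiled separately against an MPI installation
// and launched with something like mpirun -np 4): each worker sends one
// payload message and then an empty message, and the coordinator uses
// MPI_Probe plus MPI_Get_count to tell the two apart, just as the loop below
// does.
//
//   #include <mpi.h>
//   #include <cstdio>
//   #include <string>
//
//   int main(int argc, char** argv) {
//     MPI_Init(&argc, &argv);
//     int rank, size;
//     MPI_Comm_rank(MPI_COMM_WORLD, &rank);
//     MPI_Comm_size(MPI_COMM_WORLD, &size);
//     const int kTag = 1;
//     if (rank != 0) {
//       // Each worker sends one named request, then a zero-length DONE.
//       std::string request = "tensor_" + std::to_string(rank);
//       MPI_Send(request.c_str(), request.length() + 1, MPI_BYTE, 0, kTag,
//                MPI_COMM_WORLD);
//       MPI_Send(NULL, 0, MPI_BYTE, 0, kTag, MPI_COMM_WORLD);
//     } else {
//       int done = 1;  // the coordinator counts itself as already done
//       while (done != size) {
//         MPI_Status status;
//         MPI_Probe(MPI_ANY_SOURCE, kTag, MPI_COMM_WORLD, &status);
//         int len;
//         MPI_Get_count(&status, MPI_BYTE, &len);
//         if (len == 0) {  // a zero-length message marks a finished worker
//           MPI_Recv(NULL, 0, MPI_BYTE, status.MPI_SOURCE, kTag,
//                    MPI_COMM_WORLD, MPI_STATUS_IGNORE);
//           done++;
//           continue;
//         }
//         std::string buf(len, '\0');
//         MPI_Recv(&buf[0], len, MPI_BYTE, status.MPI_SOURCE, kTag,
//                  MPI_COMM_WORLD, MPI_STATUS_IGNORE);
//         std::printf("request '%s' from rank %d\n", buf.c_str(),
//                     status.MPI_SOURCE);
//       }
//     }
//     MPI_Finalize();
//     return 0;
//   }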
572 // TODO: Use the global mpi_global state variable instead of a local one
573 void BackgroundThreadLoop() {
574 #if GOOGLE_CUDA
575   // Set the device, so that this thread uses the same GPU context as the
576   // calling thread.
577   // TODO: Ensure that this is operating correctly. The background thread
578   // needs to be able to control all GPUs that the rank has access to, which
579   // might be more than one GPU. Tensors could be resident in any of the
580   // GPUs, so the background thread's accumulate and copy kernels might need
581   // to correctly set the device and it might be necessary for the background
582   // thread to manage multiple streams.
583   cudaSetDevice(mpi_global.device);
584   cudaStreamCreate(&mpi_global.stream);
585 #endif
586 
587   // Initialize MPI. This must happen on the background thread, since not all
588   // MPI implementations support being called from multiple threads.
589   auto init_result = MPI_Init(NULL, NULL);
590   if (init_result != MPI_SUCCESS) {
591     mpi_global.init_status =
592         errors::Unknown("Could not initialize MPI; MPI_Init() failed.");
593     mpi_global.initialization_done = true;
594     mpi_global.cv.notify_all();
595     return;
596   } else {
597     mpi_global.init_status = Status::OK();
598   }
599 
600   // Get MPI rank to determine if we are rank zero.
601   int rank;
602   MPI_Comm_rank(MPI_COMM_WORLD, &rank);
603   bool is_coordinator = rank == 0;
604 
605   // Get MPI size to determine how many tensors to wait for before reducing.
606   int size;
607   MPI_Comm_size(MPI_COMM_WORLD, &size);
608 
609   // Determine local rank by querying the local communicator.
610   MPI_Comm local_comm;
611   MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL,
612                       &local_comm);
613   int local_rank;
614   MPI_Comm_rank(local_comm, &local_rank);
615 
616   mpi_global.rank = rank;
617   mpi_global.local_rank = local_rank;
618   mpi_global.size = size;
619   mpi_global.initialization_done = true;
620 
621   // Notify calling thread that initialization is complete
622   mpi_global.cv.notify_all();
623 
624   // TODO: MOVE MESSAGE TABLE INITIALIZATION TO LIBRARY LOAD!
625   // Initialize the tensor count table. No tensors are available yet.
626   if (is_coordinator) {
627     mpi_global.message_table =
628         std::unique_ptr<MessageTable>(new MessageTable());
629   }
630 
631   // The coordinator sends a SHUTDOWN message to trigger shutdown.
632   bool should_shut_down = false;
633   do {
634     // TODO: Eliminate the need for thread sleep by making all activity
635     // depend on other activity (e.g. condition or MPI waits).
636     std::this_thread::sleep_for(std::chrono::milliseconds(1));
637 
638     // Copy the data structures from global state under this lock.
639     // However, don't keep the lock for the rest of the loop, so that
640     // enqueued stream callbacks can continue.
641     std::queue<MPIRequest> message_queue;
642     {
643       mutex_lock guard(mpi_global.mu);
644       while (!mpi_global.message_queue.empty()) {
645         MPIRequest message = mpi_global.message_queue.front();
646         mpi_global.message_queue.pop();
647         message_queue.push(message);
648       }
649     }
650 
651     // Collect all tensors that are ready to be reduced. Record them in the
652     // tensor count table (rank zero) or send them to rank zero to be
653     // recorded (everyone else).
654     std::vector<std::string> ready_to_reduce;
655     while (!message_queue.empty()) {
656       // Pop the first available message
657       MPIRequest message = message_queue.front();
658       message_queue.pop();
659 
660       if (is_coordinator) {
661         bool reduce =
662             IncrementTensorCount(mpi_global.message_table, message, size);
663         if (reduce) {
664           ready_to_reduce.push_back(message.tensor_name());
665         }
666       } else {
667         std::string encoded_message;
668         message.SerializeToString(&encoded_message);
669         MPI_Send(encoded_message.c_str(), encoded_message.length() + 1,
670                  MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD);
671       }
672     }
673 
674     // Rank zero has put all its own tensors in the tensor count table.
675     // Now, it should count all the tensors that are coming from other
676     // ranks at this tick. It should keep getting tensors until it gets a
677     // DONE message from all the other ranks.
678     if (is_coordinator) {
679       // Count of DONE messages. Keep receiving messages until the number
680       // of DONE messages equals the number of processes. Initialize to
681       // one since the coordinator is effectively done.
682       int completed_ranks = 1;
683       while (completed_ranks != size) {
684         MPI_Status status;
685         MPI_Probe(MPI_ANY_SOURCE, TAG_NOTIFY, MPI_COMM_WORLD, &status);
686 
687         // Find number of characters in message (including zero byte).
688         int source_rank = status.MPI_SOURCE;
689         int msg_length;
690         MPI_Get_count(&status, MPI_BYTE, &msg_length);
691 
692         // If the length is zero, this is a DONE message.
693         if (msg_length == 0) {
694           completed_ranks++;
695           MPI_Recv(NULL, 0, MPI_BYTE, source_rank, TAG_NOTIFY, MPI_COMM_WORLD,
696                    &status);
697           continue;
698         }
699 
700         // Get the encoded request from MPI into an std::string.
701         char* buffer = new char[msg_length];
702         MPI_Recv(buffer, msg_length, MPI_BYTE, source_rank, TAG_NOTIFY,
703                  MPI_COMM_WORLD, &status);
704         std::string received_data(buffer, msg_length - 1);
705         delete[] buffer;
706 
707         MPIRequest received_message;
708         received_message.ParseFromString(received_data);
709         auto received_name = received_message.tensor_name();
710 
711         bool reduce = IncrementTensorCount(mpi_global.message_table,
712                                            received_message, size);
713         if (reduce) {
714           ready_to_reduce.push_back(received_name);
715         }
716       }
717 
718       // At this point, rank zero should have a fully updated tensor
719       // count table and should know all the tensors that need to be
720       // reduced or gathered, and everyone else should have sent all
721       // their information to rank zero. We can now do reductions and
722       // gathers; rank zero will choose which ones and in what order,
723       // and will notify the other ranks before doing each reduction.
724       for (int i = 0; i < ready_to_reduce.size(); i++) {
725         // Notify all nodes which tensor we'd like to reduce now
726         auto name = ready_to_reduce[i];
727         MPIResponse response =
728             ConstructMPIResponse(mpi_global.message_table, name);
729 
730         std::string encoded_response;
731         response.SerializeToString(&encoded_response);
732         for (int r = 1; r < size; r++) {
733           MPI_Send(encoded_response.c_str(), encoded_response.length() + 1,
734                    MPI_BYTE, r, TAG_NOTIFY, MPI_COMM_WORLD);
735         }
736 
737         // Perform the reduction. All nodes should end up performing
738         // the same reduction.
739         PerformCollectiveOp(mpi_global.tensor_table, response);
740       }
741 
742       // Notify all nodes that we are done with the reductions for this
743       // tick.
744       MPIResponse done_response;
745       should_shut_down = mpi_global.shut_down;
746       done_response.set_response_type(
747           mpi_global.shut_down ? MPIResponse::SHUTDOWN : MPIResponse::DONE);
748       std::string encoded_response;
749       done_response.SerializeToString(&encoded_response);
750       for (int r = 1; r < size; r++) {
751         MPI_Send(encoded_response.c_str(), encoded_response.length() + 1,
752                  MPI_BYTE, r, TAG_NOTIFY, MPI_COMM_WORLD);
753       }
754     } else {
755       // Notify the coordinator that this node is done sending messages.
756       // A DONE message is encoded as a zero-length message.
757       MPI_Send(NULL, 0, MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD);
758 
759       // Receive responses from rank zero describing which tensors to reduce.
760       // Once we receive a DONE response, stop waiting for more messages.
761       while (true) {
762         MPI_Status status;
763         MPI_Probe(0, TAG_NOTIFY, MPI_COMM_WORLD, &status);
764 
765         // Find number of characters in message (including zero byte).
766         int msg_length;
767         MPI_Get_count(&status, MPI_BYTE, &msg_length);
768 
769         // Get the encoded response from MPI into an std::string.
770         char* buffer = new char[msg_length];
771         MPI_Recv(buffer, msg_length, MPI_BYTE, 0, TAG_NOTIFY, MPI_COMM_WORLD,
772                  &status);
773         std::string received_message(buffer, msg_length - 1);
774         delete[] buffer;
775 
776         MPIResponse response;
777         response.ParseFromString(received_message);
778         if (response.response_type() == MPIResponse::DONE) {
779           // No more messages this tick
780           break;
781         } else if (response.response_type() == MPIResponse::SHUTDOWN) {
782           // No more messages this tick, and the background thread
783           // should shut down
784           should_shut_down = true;
785           break;
786         } else {
787           // Process the current message
788           PerformCollectiveOp(mpi_global.tensor_table, response);
789         }
790       }
791     }
792   } while (!should_shut_down);
793 
794   MPI_Finalize();
795 }
796 
797 // Initialize MPI and start the MPI background thread. Ensure that this is
798 // only done once no matter how many times this function is called.
799 Status InitializeMPIOnce(bool gpu) {
800   // Ensure MPI is only initialized once.
801   if (mpi_global.initialized_flag.test_and_set()) return mpi_global.init_status;
802 
803   mpi_global.device = -1;
804 #if GOOGLE_CUDA
805   if (gpu) {
806     cudaGetDevice(&mpi_global.device);
807   }
808 #endif
809 
810   // Start the MPI background thread, which performs the MPI initialization.
811   // TODO: Change this to a Tensorflow thread
812   mpi_global.background_thread = std::thread(BackgroundThreadLoop);
813 
814   // Wait to ensure that the background thread has finished initializing MPI
815   mutex_lock guard(mpi_global.mu);
816   mpi_global.cv.wait(guard);
817   if (!mpi_global.initialization_done) {
818     mpi_global.init_status =
819         errors::Unknown("Failed to wait for MPI initialization.");
820   }
821 
822   return mpi_global.init_status;
823 }
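
// The initialize-exactly-once pattern above (an atomic_flag guard plus a
// condition variable that the background thread notifies) can be sketched in
// isolation as follows. This is an illustration with made-up names, not the
// code above; it waits on an explicit predicate, which guards against missed
// notifications and spurious wakeups.
//
//   #include <atomic>
//   #include <condition_variable>
//   #include <cstdio>
//   #include <mutex>
//   #include <thread>
//
//   namespace {
//   std::atomic_flag g_started = ATOMIC_FLAG_INIT;
//   std::mutex g_mu;
//   std::condition_variable g_cv;
//   bool g_ready = false;
//   std::thread g_worker;
//   }  // namespace
//
//   // Starts the background worker exactly once; every caller blocks until
//   // the worker has finished its one-time setup.
//   void EnsureStartedOnce() {
//     if (!g_started.test_and_set()) {
//       g_worker = std::thread([] {
//         // ... one-time setup goes here (MPI_Init in the code above) ...
//         {
//           std::lock_guard<std::mutex> lock(g_mu);
//           g_ready = true;
//         }
//         g_cv.notify_all();
//       });
//     }
//     std::unique_lock<std::mutex> lock(g_mu);
//     g_cv.wait(lock, [] { return g_ready; });
//   }
//
//   int main() {
//     EnsureStartedOnce();
//     std::printf("initialized\n");
//     g_worker.join();
//     return 0;
//   }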
824 
825 // Check that MPI is initialized.
826 Status IsMPIInitialized() {
827   if (!mpi_global.initialization_done) {
828     return errors::FailedPrecondition(
829         "MPI has not been initialized; use tf.contrib.mpi.Session.");
830   }
831   return Status::OK();
832 }
833 
834 // This function (called from the callback set up in MPIAll*Op::ComputeAsync)
835 // only adds the op's record into the local op queue (to track the op's
836 // progress), and sends a message to the coordinator indicating that this rank
837 // is ready to begin. The MPI background thread will handle the MPI message.
838 void EnqueueTensorCollective(CollectiveOpRecord record,
839                              MPIRequest::RequestType rtype) {
840   const Tensor* input_tensor = record.in_t;
841   MPIRequest message;
842   message.set_request_rank(record.rank);
843   message.set_tensor_name(record.name);
844   message.set_tensor_type(record.dtype);
845   message.set_request_type(rtype);
846   input_tensor->shape().AsProto(message.mutable_tensor_shape());
847 
848   mutex_lock guard(mpi_global.mu);
849   mpi_global.tensor_table.emplace(record.name, record);
850   mpi_global.message_queue.push(message);
851 }
852 
853 }  // namespace
854 
855 #if GOOGLE_CUDA
856 cudaStream_t CudaStreamForMPI() { return mpi_global.stream; }
857 #endif
858 
859 // Op to initialize MPI in the current process. The settings used in the
860 // configuration must be the same as those used for all future MPI ops.
861 template <typename Device>
862 class MPIInitOp : public OpKernel {
863  public:
864   explicit MPIInitOp(OpKernelConstruction* context) : OpKernel(context) {}
865 
866   void Compute(OpKernelContext* context) override {
867     bool on_gpu = IsGPUDevice<Device>();
868     OP_REQUIRES_OK(context, InitializeMPIOnce(on_gpu));
869   }
870 };
871 
872 REGISTER_KERNEL_BUILDER(Name("MPIInit").Device(DEVICE_CPU),
873                         MPIInitOp<CPUDevice>);
874 #if GOOGLE_CUDA
875 REGISTER_KERNEL_BUILDER(Name("MPIInit").Device(DEVICE_GPU),
876                         MPIInitOp<GPUDevice>);
877 #endif
878 
879 // Op to get the current MPI Size.
880 template <typename Device>
881 class MPISizeOp : public OpKernel {
882  public:
883   explicit MPISizeOp(OpKernelConstruction* context) : OpKernel(context) {}
884 
885   void Compute(OpKernelContext* context) override {
886     OP_REQUIRES_OK(context, IsMPIInitialized());
887 
888     // Write integer to output tensor
889     Tensor* output;
890     OP_REQUIRES_OK(context,
891                    context->allocate_output(0, TensorShape({}), &output));
892 
893     auto flat = output->flat<int>();
894     flat(0) = mpi_global.size;
895   }
896 };
897 
898 REGISTER_KERNEL_BUILDER(Name("MPISize").Device(DEVICE_CPU),
899                         MPISizeOp<CPUDevice>);
900 #if GOOGLE_CUDA
901 REGISTER_KERNEL_BUILDER(Name("MPISize").Device(DEVICE_GPU).HostMemory("size"),
902                         MPISizeOp<GPUDevice>);
903 #endif
904 
905 // Op to get the current MPI Rank.
906 template <typename Device>
907 class MPIRankOp : public OpKernel {
908  public:
909   explicit MPIRankOp(OpKernelConstruction* context) : OpKernel(context) {}
910 
911   void Compute(OpKernelContext* context) override {
912     OP_REQUIRES_OK(context, IsMPIInitialized());
913 
914     // Write integer to output tensor
915     Tensor* output;
916     OP_REQUIRES_OK(context,
917                    context->allocate_output(0, TensorShape({}), &output));
918 
919     auto flat = output->flat<int>();
920     flat(0) = mpi_global.rank;
921   }
922 };
923 
924 REGISTER_KERNEL_BUILDER(Name("MPIRank").Device(DEVICE_CPU),
925                         MPIRankOp<CPUDevice>);
926 #if GOOGLE_CUDA
927 REGISTER_KERNEL_BUILDER(Name("MPIRank").Device(DEVICE_GPU).HostMemory("rank"),
928                         MPIRankOp<GPUDevice>);
929 #endif
930 
931 // Op to get the current local MPI Rank.
932 template <typename Device>
933 class MPILocalRankOp : public OpKernel {
934  public:
935   explicit MPILocalRankOp(OpKernelConstruction* context) : OpKernel(context) {}
936 
937   void Compute(OpKernelContext* context) override {
938     OP_REQUIRES_OK(context, IsMPIInitialized());
939 
940     // Write integer to output tensor
941     Tensor* output;
942     OP_REQUIRES_OK(context,
943                    context->allocate_output(0, TensorShape({}), &output));
944 
945     auto flat = output->flat<int>();
946     flat(0) = mpi_global.local_rank;
947   }
948 };
949 
950 REGISTER_KERNEL_BUILDER(Name("MPILocalRank").Device(DEVICE_CPU),
951                         MPILocalRankOp<CPUDevice>);
952 #if GOOGLE_CUDA
953 REGISTER_KERNEL_BUILDER(
954     Name("MPILocalRank").Device(DEVICE_GPU).HostMemory("rank"),
955     MPILocalRankOp<GPUDevice>);
956 #endif
957 
958 template <typename Device>
959 class MPIAllreduceOp : public AsyncOpKernel {
960  public:
961   explicit MPIAllreduceOp(OpKernelConstruction* context)
962       : AsyncOpKernel(context) {}
963 
964   // Although this op is handled asynchronously, the ComputeAsync call is
965   // very inexpensive. It only sets up a CollectiveOpRecord and places it
966   // in the table for the background thread to handle. Thus, we do not need
967   // a TF pool thread to perform the op.
968   bool IsExpensive() override { return false; }
969 
970   void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
971     OP_REQUIRES_OK_ASYNC(context, IsMPIInitialized(), done);
972     const Tensor* input_tensor = &context->input(0);
973     Tensor* output_tensor;
974     OP_REQUIRES_OK_ASYNC(
975         context,
976         context->allocate_output(0, input_tensor->shape(), &output_tensor),
977         done);
978 
979     // Record allocated on stack so op can fail without memory leak
980     CollectiveOpRecord record;
981     record.name = name();
982     record.context = context;
983     record.in_t = input_tensor;
984     record.out_t = output_tensor;
985     record.on_gpu = IsGPUDevice<Device>();
986     record.dtype = input_tensor->dtype();
987 
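    // The temp tensor below is scratch space for the ring allreduce: it holds
    // ceil(NumElements / size) elements, i.e. one ring segment of the input.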
988     const size_t temp_size =
989         (input_tensor->NumElements() + mpi_global.size - 1) / mpi_global.size;
990     TensorShape temp_shape;
991     temp_shape.AddDim(temp_size);
992     OP_REQUIRES_OK_ASYNC(context,
993                          context->allocate_temp(input_tensor->dtype(),
994                                                 temp_shape, &record.temp_t),
995                          done);
996 
997     auto allreduce_done_callback = [done, context](StatusOr<Tensor> status) {
998       context->SetStatus(status.status());
999       done();
1000     };
1001     record.callback = allreduce_done_callback;
1002 
1003     auto allreduce_launch_callback = [record] {
1004       EnqueueTensorCollective(record, MPIRequest::ALLREDUCE);
1005     };
1006 
1007     // If we are on a CPU, our device context will be null and we can't
1008     // get a stream to enqueue this on. On a CPU this op is called when the
1009     // data is already available, so we can just immediately do the
1010     // allreduce; we don't have to wait for the data to get populated.
1011 #if GOOGLE_CUDA
1012     auto device_context = context->op_device_context();
1013     if (device_context == nullptr) {
1014       allreduce_launch_callback();
1015     } else {
1016       auto stream = device_context->stream();
1017       stream->ThenDoHostCallback(allreduce_launch_callback);
1018     }
1019 #else
1020     allreduce_launch_callback();
1021 #endif
1022   }
1023 };
1024 
1025 REGISTER_KERNEL_BUILDER(Name("MPIAllreduce").Device(DEVICE_CPU),
1026                         MPIAllreduceOp<CPUDevice>);
1027 #if GOOGLE_CUDA
1028 REGISTER_KERNEL_BUILDER(Name("MPIAllreduce").Device(DEVICE_GPU),
1029                         MPIAllreduceOp<GPUDevice>);
1030 #endif
1031 
1032 template <typename Device>
1033 class MPIAllgatherOp : public AsyncOpKernel {
1034  public:
1035   explicit MPIAllgatherOp(OpKernelConstruction* context)
1036       : AsyncOpKernel(context) {}
1037 
1038   // Although this op is handled asynchronously, the ComputeAsync call is
1039   // very inexpensive. It only sets up a CollectiveOpRecord and places it
1040   // in the table for the background thread to handle. Thus, we do not need
1041   // a TF pool thread to perform the op.
1042   bool IsExpensive() override { return false; }
1043 
1044   void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
1045     OP_REQUIRES_OK_ASYNC(context, IsMPIInitialized(), done);
1046     const Tensor* input_tensor = &context->input(0);
1047     const Tensor* sizing_tensor = &context->input(1);
1048 
1049     // Record allocated on stack so op can fail without memory leak
1050     CollectiveOpRecord record;
1051     record.name = name();
1052     record.context = context;
1053     record.in_t = input_tensor;
1054     record.on_gpu = IsGPUDevice<Device>();
1055 
1056     // Construct the output size from the sizing tensor
1057     size_t output_first_dim = 0;
1058     if (sizing_tensor->shape().dims() == 0) {
1059       // 0-dim sizing_tensor implies that the op is just gathering
1060       // a single element from each rank
1061       output_first_dim = mpi_global.size;
1062       for (int i = 0; i < mpi_global.size; i++) {
1063         record.sizes_vec.push_back(1);
1064       }
1065     } else {
1066       // Collect the total output tensor sizing from the sizing tensor
1067       // NOTE: The sizing tensor is forced to be placed on the CPU by
1068       // declaring the input as HostMemory, so it is valid to read it here.
1069       const int64* sizing_array =
1070           (const int64*)sizing_tensor->tensor_data().data();
1071       for (int i = 0; i < mpi_global.size; i++) {
1072         record.sizes_vec.push_back(sizing_array[i]);
1073         output_first_dim += sizing_array[i];
1074       }
1075     }
1076 
1077     TensorShape output_shape;
1078     output_shape.AddDim(output_first_dim);
1079     for (int i = 1; i < input_tensor->shape().dims(); i++) {
1080       output_shape.AddDim(input_tensor->shape().dim_size(i));
1081     }
1082 
1083     Tensor* output_tensor;
1084     OP_REQUIRES_OK_ASYNC(
1085         context, context->allocate_output(0, output_shape, &output_tensor),
1086         done);
1087 
1088     record.out_t = output_tensor;
1089     record.dtype = input_tensor->dtype();
1090 
1091     auto allgather_done_callback = [done, context](StatusOr<Tensor> status) {
1092       context->SetStatus(status.status());
1093       done();
1094     };
1095     record.callback = allgather_done_callback;
1096 
1097     auto allgather_launch_callback = [record] {
1098       EnqueueTensorCollective(record, MPIRequest::ALLGATHER);
1099     };
1100 
1101     // If we are on a CPU, our device context will be null and we can't
1102     // get a stream to enqueue this on. On a CPU this op is called when the
1103     // data is already available, so we can just immediately do the
1104     // allgather; we don't have to wait for the data to get populated.
1105 #if GOOGLE_CUDA
1106     auto device_context = context->op_device_context();
1107     if (device_context == nullptr) {
1108       allgather_launch_callback();
1109     } else {
1110       auto stream = device_context->stream();
1111       stream->ThenDoHostCallback(allgather_launch_callback);
1112     }
1113 #else
1114     allgather_launch_callback();
1115 #endif
1116   }
1117 };
1118 
1119 REGISTER_KERNEL_BUILDER(
1120     Name("MPIAllgather").Device(DEVICE_CPU).HostMemory("sizes"),
1121     MPIAllgatherOp<CPUDevice>);
1122 #if GOOGLE_CUDA
1123 REGISTER_KERNEL_BUILDER(
1124     Name("MPIAllgather").Device(DEVICE_GPU).HostMemory("sizes"),
1125     MPIAllgatherOp<GPUDevice>);
1126 #endif
1127 
1128 }  // namespace mpi_collectives
1129 }  // namespace contrib
1130 }  // namespace tensorflow
1131 
1132 #endif  // TENSORFLOW_USE_MPI
1133